Skip to content

API Reference: Attacks

Base class

auditml.attacks.base.BaseAttack

Bases: ABC

Abstract base for all privacy attacks.

Parameters:

Name Type Description Default
target_model Module

The trained model being attacked. The constructor puts it into eval() mode.

required
config

Optional AuditML configuration. When provided (YAML / CLI workflow) the attack reads its params from it. When None each subclass uses its own explicit keyword parameters instead.

None
device device | str

Torch device the model lives on.

'cpu'
Source code in src/auditml/attacks/base.py
class BaseAttack(ABC):
    """Abstract base for all privacy attacks.

    Parameters
    ----------
    target_model:
        The trained model being attacked. The constructor puts it into
        ``eval()`` mode.
    config:
        Optional AuditML configuration. When provided (YAML / CLI
        workflow) the attack reads its params from it. When ``None``
        each subclass uses its own explicit keyword parameters instead.
    device:
        Torch device the model lives on. Input batches are moved to this
        device; the model itself is not moved.
    """

    attack_name: str = "base"  # overridden by each subclass

    def __init__(
        self,
        target_model: nn.Module,
        config=None,
        device: torch.device | str = "cpu",
    ) -> None:
        self.target_model = target_model
        self.target_model.eval()  # always eval mode for attacks
        self.config = config
        self.device = torch.device(device)
        # Populated by ``run()``; ``evaluate()`` refuses to work while None.
        self.result: AttackResult | None = None

    # ------------------------------------------------------------------
    # Abstract methods — each concrete attack MUST implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def run(
        self,
        member_loader: DataLoader,
        nonmember_loader: DataLoader,
    ) -> AttackResult:
        """Execute the attack.

        Parameters
        ----------
        member_loader:
            DataLoader over samples the target model WAS trained on.
        nonmember_loader:
            DataLoader over samples the target model was NOT trained on.

        Returns
        -------
        AttackResult
            Predictions, ground truth, and confidence scores.
        """
        ...

    # ------------------------------------------------------------------
    # Evaluation — shared across all attacks
    # ------------------------------------------------------------------

    def evaluate(self) -> dict[str, float]:
        """Compute standard metrics from the most recent ``run()``.

        Returns
        -------
        dict
            Keys: accuracy, precision, recall, f1, auc_roc, auc_pr,
            tpr_at_1fpr, tpr_at_01fpr.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before evaluate().")
        return self._compute_metrics(
            self.result.predictions,
            self.result.ground_truth,
            self.result.confidence_scores,
        )

    # ------------------------------------------------------------------
    # Shared utility methods — used by multiple attacks
    # ------------------------------------------------------------------

    @torch.no_grad()
    def get_model_outputs(
        self, loader: DataLoader,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the target model on every sample in *loader*.

        Runs without gradient tracking. Inputs are moved to ``self.device``;
        every returned array is a CPU numpy array, in loader order.

        Returns
        -------
        (probabilities, logits, labels)
            - probabilities: ``(N, num_classes)`` softmax output
            - logits: ``(N, num_classes)`` raw model output
            - labels: ``(N,)`` true class labels from the dataset

        Raises
        ------
        ValueError
            If *loader* yields no batches.
        """
        all_probs: list[np.ndarray] = []
        all_logits: list[np.ndarray] = []
        all_labels: list[np.ndarray] = []

        for inputs, targets in loader:
            inputs = inputs.to(self.device)
            logits = self.target_model(inputs)
            probs = F.softmax(logits, dim=1)

            all_logits.append(logits.cpu().numpy())
            all_probs.append(probs.cpu().numpy())
            # ``.cpu()`` is a no-op for CPU tensors but keeps this working if
            # a dataset / collate_fn already placed the labels on a GPU.
            all_labels.append(targets.cpu().numpy())

        if not all_logits:
            # np.concatenate([]) would raise an opaque numpy error instead.
            raise ValueError(
                "DataLoader yielded no batches; cannot compute model outputs."
            )

        return (
            np.concatenate(all_probs),
            np.concatenate(all_logits),
            np.concatenate(all_labels),
        )

    @torch.no_grad()
    def get_loss_values(self, loader: DataLoader) -> np.ndarray:
        """Compute **per-sample** cross-entropy loss for every sample.

        This is critical for threshold-based MIA: training samples
        typically have lower loss because the model has seen them before.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — one loss value per sample, in loader order.

        Raises
        ------
        ValueError
            If *loader* yields no batches.
        """
        criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample
        all_losses: list[np.ndarray] = []

        for inputs, targets in loader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            logits = self.target_model(inputs)
            losses = criterion(logits, targets)
            all_losses.append(losses.cpu().numpy())

        if not all_losses:
            # np.concatenate([]) would raise an opaque numpy error instead.
            raise ValueError(
                "DataLoader yielded no batches; cannot compute loss values."
            )

        return np.concatenate(all_losses)

    # ------------------------------------------------------------------
    # Metrics computation
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_metrics(
        predictions: np.ndarray,
        ground_truth: np.ndarray,
        confidence_scores: np.ndarray,
    ) -> dict[str, float]:
        """Compute a comprehensive set of binary classification metrics.

        Parameters
        ----------
        predictions:
            Binary array (0/1) — the attack's prediction.
        ground_truth:
            Binary array (0/1) — the true membership label.
        confidence_scores:
            Continuous score — higher means "more likely member".

        Returns
        -------
        dict with keys:
            accuracy, precision, recall, f1, auc_roc, auc_pr,
            tpr_at_1fpr, tpr_at_01fpr
            The ROC/PR-based keys are all 0.0 when *ground_truth* does not
            contain exactly two distinct labels.
        """
        metrics: dict[str, float] = {
            "accuracy": float(accuracy_score(ground_truth, predictions)),
            "precision": float(precision_score(ground_truth, predictions, zero_division=0)),
            "recall": float(recall_score(ground_truth, predictions, zero_division=0)),
            "f1": float(f1_score(ground_truth, predictions, zero_division=0)),
        }

        # ROC-based metrics (need continuous scores)
        if len(np.unique(ground_truth)) == 2:
            fpr, tpr, _ = roc_curve(ground_truth, confidence_scores)
            metrics["auc_roc"] = float(roc_auc_score(ground_truth, confidence_scores))

            # TPR at specific FPR thresholds — realistic adversary constraints.
            # Linear interpolation along the (non-decreasing) ROC curve.
            metrics["tpr_at_1fpr"] = float(np.interp(0.01, fpr, tpr))
            metrics["tpr_at_01fpr"] = float(np.interp(0.001, fpr, tpr))

            # Precision-Recall AUC (trapezoidal rule over the PR curve).
            # NOTE(review): scikit-learn documents trapezoidal PR-AUC as
            # potentially optimistic vs. average_precision_score — kept as-is
            # so reported values stay comparable with earlier runs.
            prec_arr, rec_arr, _ = precision_recall_curve(ground_truth, confidence_scores)
            metrics["auc_pr"] = float(auc(rec_arr, prec_arr))
        else:
            # Edge case: if all samples have the same label, AUC is undefined
            metrics["auc_roc"] = 0.0
            metrics["tpr_at_1fpr"] = 0.0
            metrics["tpr_at_01fpr"] = 0.0
            metrics["auc_pr"] = 0.0

        return metrics

run(member_loader: DataLoader, nonmember_loader: DataLoader) -> AttackResult abstractmethod

Execute the attack.

Parameters:

Name Type Description Default
member_loader DataLoader

DataLoader over samples the target model WAS trained on.

required
nonmember_loader DataLoader

DataLoader over samples the target model was NOT trained on.

required

Returns:

Type Description
AttackResult

Predictions, ground truth, and confidence scores.

Source code in src/auditml/attacks/base.py
@abstractmethod
def run(
    self,
    member_loader: DataLoader,
    nonmember_loader: DataLoader,
) -> AttackResult:
    """Perform the attack against the target model.

    Concrete attacks must override this method.

    Parameters
    ----------
    member_loader:
        Loader yielding samples that were part of the target model's
        training set.
    nonmember_loader:
        Loader yielding samples that were held out from training.

    Returns
    -------
    AttackResult
        The attack's predictions, the true membership labels, and the
        per-sample confidence scores.
    """
    pass

evaluate() -> dict[str, float]

Compute standard metrics from the most recent run().

Returns:

Type Description
dict

Keys: accuracy, precision, recall, f1, auc_roc, auc_pr, tpr_at_1fpr, tpr_at_01fpr.

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/base.py
def evaluate(self) -> dict[str, float]:
    """Score the latest ``run()`` output with the shared metric suite.

    Returns
    -------
    dict
        Keys: accuracy, precision, recall, f1, auc_roc, auc_pr,
        tpr_at_1fpr, tpr_at_01fpr.

    Raises
    ------
    RuntimeError
        When no ``run()`` result is available yet.
    """
    outcome = self.result
    if outcome is None:
        raise RuntimeError("Call run() before evaluate().")
    preds = outcome.predictions
    truth = outcome.ground_truth
    scores = outcome.confidence_scores
    return self._compute_metrics(preds, truth, scores)

get_model_outputs(loader: DataLoader) -> tuple[np.ndarray, np.ndarray, np.ndarray]

Run the target model on every sample in loader.

Returns:

Type Description
(probabilities, logits, labels)
  • probabilities: (N, num_classes) softmax output
  • logits: (N, num_classes) raw model output
  • labels: (N,) true class labels from the dataset
Source code in src/auditml/attacks/base.py
@torch.no_grad()
def get_model_outputs(
    self, loader: DataLoader,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run the target model on every sample in *loader*.

    Runs without gradient tracking. Inputs are moved to ``self.device``;
    every returned array is a CPU numpy array, in loader order.

    Returns
    -------
    (probabilities, logits, labels)
        - probabilities: ``(N, num_classes)`` softmax output
        - logits: ``(N, num_classes)`` raw model output
        - labels: ``(N,)`` true class labels from the dataset

    Raises
    ------
    ValueError
        If *loader* yields no batches.
    """
    all_probs: list[np.ndarray] = []
    all_logits: list[np.ndarray] = []
    all_labels: list[np.ndarray] = []

    for inputs, targets in loader:
        inputs = inputs.to(self.device)
        logits = self.target_model(inputs)
        probs = F.softmax(logits, dim=1)

        all_logits.append(logits.cpu().numpy())
        all_probs.append(probs.cpu().numpy())
        # ``.cpu()`` is a no-op for CPU tensors but keeps this working if
        # a dataset / collate_fn already placed the labels on a GPU.
        all_labels.append(targets.cpu().numpy())

    if not all_logits:
        # np.concatenate([]) would raise an opaque numpy error instead.
        raise ValueError(
            "DataLoader yielded no batches; cannot compute model outputs."
        )

    return (
        np.concatenate(all_probs),
        np.concatenate(all_logits),
        np.concatenate(all_labels),
    )

get_loss_values(loader: DataLoader) -> np.ndarray

Compute per-sample cross-entropy loss for every sample.

This is critical for threshold-based MIA: training samples typically have lower loss because the model has seen them before.

Returns:

Type Description
ndarray

Shape (N,) — one loss value per sample.

Source code in src/auditml/attacks/base.py
@torch.no_grad()
def get_loss_values(self, loader: DataLoader) -> np.ndarray:
    """Compute **per-sample** cross-entropy loss for every sample.

    This is critical for threshold-based MIA: training samples
    typically have lower loss because the model has seen them before.

    Returns
    -------
    np.ndarray
        Shape ``(N,)`` — one loss value per sample, in loader order.

    Raises
    ------
    ValueError
        If *loader* yields no batches.
    """
    criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample
    all_losses: list[np.ndarray] = []

    for inputs, targets in loader:
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        logits = self.target_model(inputs)
        losses = criterion(logits, targets)
        all_losses.append(losses.cpu().numpy())

    if not all_losses:
        # np.concatenate([]) would raise an opaque numpy error instead.
        raise ValueError(
            "DataLoader yielded no batches; cannot compute loss values."
        )

    return np.concatenate(all_losses)

Attack results

auditml.attacks.results.AttackResult dataclass

Container for the outputs of a single attack run.

Attributes:

Name Type Description
predictions ndarray

Binary array — the attack's guess for each sample. For membership inference: 1 = predicted member, 0 = predicted non-member. Length equals len(ground_truth).

ground_truth ndarray

Binary array — the true label for each sample. For membership inference: 1 = actual member, 0 = actual non-member.

confidence_scores ndarray

Continuous score per sample indicating how confident the attack is. Higher = more confident the sample is a member (for MIA) or more confident in the predicted attribute (for attribute inference). Used for ROC curves and threshold-independent evaluation.

attack_name str

Human-readable name, e.g. "mia_threshold" or "model_inversion".

metadata dict[str, Any]

Free-form dict for attack-specific extras (e.g. reconstructed images for model inversion, per-class breakdowns, etc.).

Source code in src/auditml/attacks/results.py
@dataclass
class AttackResult:
    """Outputs produced by one execution of an attack.

    Attributes
    ----------
    predictions:
        Binary per-sample decision of the attack. For membership
        inference, 1 means "predicted member" and 0 "predicted
        non-member". Must have the same length as ``ground_truth``.
    ground_truth:
        Binary per-sample true label. For membership inference, 1 means
        "actual member" and 0 "actual non-member".
    confidence_scores:
        Continuous per-sample score; larger values mean the attack is more
        confident the sample is a member (MIA) or more confident in the
        inferred attribute (attribute inference). Feeds ROC curves and
        other threshold-free evaluation. Must match ``ground_truth`` in
        length.
    attack_name:
        Human-readable identifier such as ``"mia_threshold"`` or
        ``"model_inversion"``.
    metadata:
        Arbitrary attack-specific extras (for example reconstructed images
        from model inversion or per-class breakdowns). Each instance gets
        its own fresh dict.
    """

    predictions: np.ndarray
    ground_truth: np.ndarray
    confidence_scores: np.ndarray
    attack_name: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Reject results whose per-sample arrays disagree in length."""
        expected = len(self.ground_truth)
        # Checked in this order so the first mismatching field is reported.
        for name in ("predictions", "confidence_scores"):
            actual = len(getattr(self, name))
            if actual != expected:
                raise ValueError(
                    f"{name} length ({actual}) != ground_truth length ({expected})"
                )

__post_init__() -> None

Validate array lengths match.

Source code in src/auditml/attacks/results.py
def __post_init__(self) -> None:
    """Reject results whose per-sample arrays disagree in length."""
    expected = len(self.ground_truth)
    # Checked in this order so the first mismatching field is reported.
    for name in ("predictions", "confidence_scores"):
        actual = len(getattr(self, name))
        if actual != expected:
            raise ValueError(
                f"{name} length ({actual}) != ground_truth length ({expected})"
            )

Threshold MIA

auditml.attacks.mia_threshold.ThresholdMIA

Bases: BaseAttack

Threshold-based Membership Inference Attack.

Supports three signal metrics:

  • "loss" — per-sample cross-entropy loss (lower → more likely member)
  • "confidence" — max softmax probability (higher → more likely member)
  • "entropy" — prediction entropy (lower → more likely member)

Parameters:

Name Type Description Default
target_model

The trained model to attack.

required
config

AuditML config. Reads config.attack_params.mia_threshold for metric and percentile settings.

None
device

Torch device.

'cpu'
Source code in src/auditml/attacks/mia_threshold.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
class ThresholdMIA(BaseAttack):
    """Threshold-based Membership Inference Attack.

    Supports three signal metrics:

    - ``"loss"`` — per-sample cross-entropy loss (lower → more likely member)
    - ``"confidence"`` — max softmax probability (higher → more likely member)
    - ``"entropy"`` — prediction entropy (lower → more likely member)

    Parameters
    ----------
    target_model:
        The trained model to attack.
    config:
        AuditML config. Reads ``config.attack_params.mia_threshold`` for
        ``metric`` and ``percentile`` settings.
    device:
        Torch device.
    metric:
        Keyword-only. Overrides the config's metric when not ``None``;
        defaults to ``"loss"`` when neither is provided.
    percentile:
        Keyword-only. Overrides the config's percentile when not ``None``
        (an explicit ``0`` is honoured); defaults to ``50``. Stored as
        ``self.percentile``; no method of this class currently reads it.
    """

    attack_name = "mia_threshold"

    def __init__(
        self,
        target_model,
        config=None,
        device="cpu",
        *,
        metric: str | None = None,
        percentile: float | None = None,
    ):
        super().__init__(target_model, config, device)
        # Explicit params take priority; fall back to config; then hardcoded defaults.
        # ``is None`` checks (rather than ``or``) so falsy-but-explicit values
        # such as ``percentile=0`` are not silently replaced.
        cfg_params = config.attack_params.mia_threshold if config is not None else None
        if metric is None:
            metric = cfg_params.metric if cfg_params else "loss"
        if percentile is None:
            percentile = cfg_params.percentile if cfg_params else 50
        self.metric = metric
        self.percentile = percentile

        # Intermediate values stored after run() for analysis/plotting
        self.member_scores: np.ndarray | None = None
        self.nonmember_scores: np.ndarray | None = None
        self.threshold: float | None = None
        # Class labels for per-class evaluation (stored during run())
        self.member_labels: np.ndarray | None = None
        self.nonmember_labels: np.ndarray | None = None

    # ------------------------------------------------------------------
    # Main attack logic
    # ------------------------------------------------------------------

    def run(
        self,
        member_loader: DataLoader,
        nonmember_loader: DataLoader,
    ) -> AttackResult:
        """Execute the threshold MIA.

        Steps:
        1. Compute the chosen metric (loss/confidence/entropy) and the class
           label for every member and non-member sample, in a single pass
           over each loader.
        2. Find the optimal threshold that maximises accuracy.
        3. Classify each sample as member or non-member using that
           threshold.

        Parameters
        ----------
        member_loader:
            DataLoader for samples the model WAS trained on.
        nonmember_loader:
            DataLoader for samples the model was NOT trained on.

        Returns
        -------
        AttackResult

        Raises
        ------
        ValueError
            If ``self.metric`` is not one of the supported metrics.
        """
        # Step 1: one pass per loader yields both the signal and the labels,
        # so they stay aligned even if the loader shuffles (a DataLoader with
        # shuffle=True draws a new order on every iteration) and each dataset
        # is read only once.
        self.member_scores, self.member_labels = self._compute_signal_and_labels(
            member_loader
        )
        self.nonmember_scores, self.nonmember_labels = self._compute_signal_and_labels(
            nonmember_loader
        )

        # Combine into single arrays
        all_scores = np.concatenate([self.member_scores, self.nonmember_scores])
        ground_truth = np.concatenate([
            np.ones(len(self.member_scores)),    # 1 = member
            np.zeros(len(self.nonmember_scores)),  # 0 = non-member
        ])

        # Step 2: Find optimal threshold
        self.threshold = self._find_optimal_threshold(all_scores, ground_truth)

        # Step 3: Classify using threshold
        predictions = self._apply_threshold(all_scores, self.threshold)

        # Confidence scores: how "member-like" each sample is
        # We normalise so that higher = more likely member, regardless of metric
        confidence_scores = self._scores_to_confidence(all_scores)

        self.result = AttackResult(
            predictions=predictions,
            ground_truth=ground_truth,
            confidence_scores=confidence_scores,
            attack_name=self.attack_name,
            metadata={
                "metric": self.metric,
                "threshold": self.threshold,
                "member_mean": float(np.mean(self.member_scores)),
                "nonmember_mean": float(np.mean(self.nonmember_scores)),
                "member_std": float(np.std(self.member_scores)),
                "nonmember_std": float(np.std(self.nonmember_scores)),
            },
        )
        return self.result

    # ------------------------------------------------------------------
    # Signal computation — the core of the attack
    # ------------------------------------------------------------------

    def _unknown_metric_error(self) -> ValueError:
        """Build the error raised when ``self.metric`` is unsupported."""
        return ValueError(
            f"Unknown metric {self.metric!r}. "
            f"Choose from: 'loss', 'confidence', 'entropy'"
        )

    def _compute_signal_and_labels(
        self, loader: DataLoader,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Compute the chosen metric AND the class labels in one pass.

        A single ``get_model_outputs`` call supplies probabilities, logits
        and labels, so the returned scores and labels share one sample order.

        Returns
        -------
        (scores, labels)
            ``scores`` has shape ``(N,)``; ``labels`` are the dataset
            targets in the same order.

        Raises
        ------
        ValueError
            If ``self.metric`` is not supported.
        """
        if self.metric not in ("loss", "confidence", "entropy"):
            # Fail before running the model over the whole loader.
            raise self._unknown_metric_error()

        probs, logits, labels = self.get_model_outputs(loader)
        if self.metric == "loss":
            scores = self._cross_entropy_from_logits(logits, labels)
        elif self.metric == "confidence":
            scores = self._max_probability(probs)
        else:  # "entropy"
            scores = self._prediction_entropy(probs)
        return scores, labels

    def _compute_signal(self, loader: DataLoader) -> np.ndarray:
        """Compute the chosen metric for every sample in *loader*.

        ``run()`` uses ``_compute_signal_and_labels`` instead, which returns
        the labels from the same pass.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — one score per sample.
        """
        if self.metric == "loss":
            return self._compute_loss_signal(loader)
        elif self.metric == "confidence":
            return self._compute_confidence_signal(loader)
        elif self.metric == "entropy":
            return self._compute_entropy_signal(loader)
        else:
            raise self._unknown_metric_error()

    def _compute_loss_signal(self, loader: DataLoader) -> np.ndarray:
        """Per-sample cross-entropy loss.

        Members typically have LOWER loss because the model was trained
        to minimise loss on them.
        """
        return self.get_loss_values(loader)

    def _compute_confidence_signal(self, loader: DataLoader) -> np.ndarray:
        """Max softmax probability (prediction confidence).

        Members typically have HIGHER confidence because the model has
        seen them before and is more certain about them.
        """
        probs, _, _ = self.get_model_outputs(loader)
        return self._max_probability(probs)

    def _compute_entropy_signal(self, loader: DataLoader) -> np.ndarray:
        """Prediction entropy: -sum(p * log(p)).

        Members typically have LOWER entropy (less uncertainty) because
        the model is more confident about them.
        """
        probs, _, _ = self.get_model_outputs(loader)
        return self._prediction_entropy(probs)

    @staticmethod
    def _cross_entropy_from_logits(
        logits: np.ndarray, labels: np.ndarray,
    ) -> np.ndarray:
        """Per-sample cross-entropy computed from already-collected logits.

        Uses ``F.cross_entropy(..., reduction="none")``, the functional form
        of the ``nn.CrossEntropyLoss(reduction="none")`` criterion used by
        ``get_loss_values``, so both accept the same target formats. The
        computation runs on CPU tensors wrapping the numpy arrays.
        """
        # Local import keeps this helper self-contained.
        import torch
        import torch.nn.functional as F

        losses = F.cross_entropy(
            torch.from_numpy(logits), torch.from_numpy(labels), reduction="none",
        )
        return losses.numpy()

    @staticmethod
    def _max_probability(probs: np.ndarray) -> np.ndarray:
        """Largest class probability of each row of *probs*."""
        return np.max(probs, axis=1)

    @staticmethod
    def _prediction_entropy(probs: np.ndarray) -> np.ndarray:
        """Shannon entropy (natural log) of each row of *probs*."""
        # Clip to avoid log(0)
        probs_clipped = np.clip(probs, 1e-10, 1.0)
        return -np.sum(probs_clipped * np.log(probs_clipped), axis=1)

    # ------------------------------------------------------------------
    # Threshold selection
    # ------------------------------------------------------------------

    def _find_optimal_threshold(
        self,
        scores: np.ndarray,
        ground_truth: np.ndarray,
    ) -> float:
        """Find the threshold that maximises attack accuracy.

        Delegates the search to ``_rust_find_best_threshold``.

        Parameters
        ----------
        scores:
            Raw signal values for all samples (members + non-members), NOT
            direction-normalised.
        ground_truth:
            Binary labels (1 = member, 0 = non-member).

        Returns
        -------
        float
            The optimal threshold value.
        """
        # NOTE(review): the helper receives raw scores but not ``self.metric``,
        # while ``_apply_threshold`` uses ``>=`` for confidence and ``<=`` for
        # loss/entropy. Confirm the helper's search direction matches both.
        # Use Rust-accelerated threshold scan (falls back to NumPy if unavailable)
        best_threshold, _ = _rust_find_best_threshold(scores, ground_truth)
        return best_threshold

    def _apply_threshold(
        self, scores: np.ndarray, threshold: float,
    ) -> np.ndarray:
        """Classify samples as member/non-member using *threshold*.

        The direction depends on the metric:
        - loss: score <= threshold → member (lower loss = member)
        - confidence: score >= threshold → member (higher conf = member)
        - entropy: score <= threshold → member (lower entropy = member)

        Returns
        -------
        np.ndarray
            Binary predictions (1 = member, 0 = non-member).
        """
        if self.metric == "confidence":
            # Higher confidence → more likely member
            return (scores >= threshold).astype(int)
        else:
            # Lower loss / lower entropy → more likely member
            return (scores <= threshold).astype(int)

    def _scores_to_confidence(self, scores: np.ndarray) -> np.ndarray:
        """Convert raw scores to confidence values where higher = more
        likely member.

        This normalisation is needed so that ROC curves and AUC work
        correctly regardless of which metric was used.
        """
        if self.metric == "confidence":
            # Already in the right direction: higher = more member-like
            return scores
        else:
            # Loss and entropy: lower = more member-like, so negate
            return -scores

    # ------------------------------------------------------------------
    # Label extraction
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_labels(loader: DataLoader) -> np.ndarray:
        """Extract the class labels from a DataLoader.

        Not used by ``run()``, which takes labels from the same pass as the
        scores; an independent pass like this one may see a different order
        if the loader shuffles.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — one integer class label per sample.
        """
        all_labels: list[np.ndarray] = []
        for _, targets in loader:
            all_labels.append(targets.numpy())
        return np.concatenate(all_labels)

    # ------------------------------------------------------------------
    # Per-class evaluation
    # ------------------------------------------------------------------

    def evaluate_per_class(self) -> dict[int, dict[str, float]]:
        """Compute evaluation metrics **separately for each class**.

        This reveals whether the attack works better on certain classes.
        For example, rare classes might be easier to identify as members
        because the model memorises them more.

        Returns
        -------
        dict[int, dict[str, float]]
            Mapping from class label → metric dictionary. Each inner dict
            has the same keys as ``evaluate()`` (accuracy, precision, …)
            plus ``n_samples``.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before evaluate_per_class().")

        all_labels = np.concatenate([self.member_labels, self.nonmember_labels])
        unique_classes = np.unique(all_labels)

        per_class: dict[int, dict[str, float]] = {}
        for cls in unique_classes:
            mask = all_labels == cls
            preds_cls = self.result.predictions[mask]
            gt_cls = self.result.ground_truth[mask]
            scores_cls = self.result.confidence_scores[mask]

            # Need at least 2 samples AND both member/non-member in this class
            if len(gt_cls) < 2 or len(np.unique(gt_cls)) < 2:
                # Not enough data for meaningful per-class metrics
                per_class[int(cls)] = {
                    "accuracy": float(np.mean(preds_cls == gt_cls)) if len(gt_cls) > 0 else 0.0,
                    "precision": 0.0,
                    "recall": 0.0,
                    "f1": 0.0,
                    "auc_roc": 0.0,
                    "auc_pr": 0.0,
                    "tpr_at_1fpr": 0.0,
                    "tpr_at_01fpr": 0.0,
                    "n_samples": int(mask.sum()),
                }
                continue

            metrics = self._compute_metrics(preds_cls, gt_cls, scores_cls)
            metrics["n_samples"] = int(mask.sum())
            per_class[int(cls)] = metrics

        return per_class

    # ------------------------------------------------------------------
    # Report generation
    # ------------------------------------------------------------------

    def generate_report(self, output_dir: str | Path) -> Path:
        """Generate a complete evaluation report with metrics and plots.

        Creates the following files in *output_dir*:

        - ``metrics.json`` — overall evaluation metrics
        - ``per_class_metrics.json`` — per-class breakdown
        - ``roc_curve.png`` — ROC curve plot
        - ``score_distributions.png`` — histogram of member vs non-member scores
        - ``per_class_accuracy.png`` — bar chart of per-class attack accuracy
        - ``summary.txt`` — human-readable text summary

        Parameters
        ----------
        output_dir:
            Directory where all report files are saved. Created if it
            doesn't exist.

        Returns
        -------
        Path
            The output directory.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before generate_report().")

        # Lazy import to avoid matplotlib overhead when not needed
        from auditml.attacks.visualization import (
            plot_per_class_metrics,
            plot_roc_curve,
            plot_score_distributions,
        )

        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # 1. Overall metrics
        metrics = self.evaluate()
        with open(out / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        # 2. Per-class metrics
        per_class = self.evaluate_per_class()
        # JSON keys must be strings
        per_class_str = {str(k): v for k, v in per_class.items()}
        with open(out / "per_class_metrics.json", "w") as f:
            json.dump(per_class_str, f, indent=2)

        # 3. ROC curve
        plot_roc_curve(
            ground_truth=self.result.ground_truth,
            confidence_scores=self.result.confidence_scores,
            save_path=out / "roc_curve.png",
        )

        # 4. Score distributions histogram
        plot_score_distributions(
            member_scores=self.member_scores,
            nonmember_scores=self.nonmember_scores,
            metric_name=self.metric,
            threshold=self.threshold,
            save_path=out / "score_distributions.png",
        )

        # 5. Per-class accuracy bar chart
        plot_per_class_metrics(
            per_class_metrics=per_class,
            save_path=out / "per_class_accuracy.png",
        )

        # 6. Human-readable summary
        self._write_summary(out / "summary.txt", metrics, per_class)

        return out

    def _write_summary(
        self,
        path: Path,
        metrics: dict[str, float],
        per_class: dict[int, dict[str, float]],
    ) -> None:
        """Write a human-readable text summary of the attack results."""
        lines = [
            "=" * 60,
            "AuditML — Threshold MIA Report",
            "=" * 60,
            "",
            f"Metric used:     {self.metric}",
            f"Threshold:       {self.threshold:.6f}",
            f"Total samples:   {len(self.result.predictions)}",
            f"  Members:       {int(self.result.ground_truth.sum())}",
            f"  Non-members:   {int((1 - self.result.ground_truth).sum())}",
            "",
            "--- Overall Metrics ---",
        ]
        for key, val in metrics.items():
            lines.append(f"  {key:<20s}: {val:.4f}")

        lines.append("")
        lines.append("--- Per-Class Breakdown ---")
        for cls in sorted(per_class.keys()):
            m = per_class[cls]
            lines.append(
                f"  Class {cls:>3d}:  acc={m['accuracy']:.3f}  "
                f"auc={m['auc_roc']:.3f}  n={m['n_samples']}"
            )

        lines.append("")
        lines.append("--- Metadata ---")
        for key, val in self.result.metadata.items():
            lines.append(f"  {key}: {val}")

        lines.append("")
        path.write_text("\n".join(lines))

run(member_loader: DataLoader, nonmember_loader: DataLoader) -> AttackResult

Execute the threshold MIA.

Steps:

1. Compute the chosen metric (loss/confidence/entropy) for every member and non-member sample.
2. Find the optimal threshold that maximises accuracy.
3. Classify each sample as member or non-member using that threshold.

Parameters:

Name Type Description Default
member_loader DataLoader

DataLoader for samples the model WAS trained on.

required
nonmember_loader DataLoader

DataLoader for samples the model was NOT trained on.

required

Returns:

Type Description
AttackResult
Source code in src/auditml/attacks/mia_threshold.py
def run(
    self,
    member_loader: DataLoader,
    nonmember_loader: DataLoader,
) -> AttackResult:
    """Run the threshold membership-inference attack end to end.

    The attack proceeds in three stages:

    1. Score every member and non-member sample with the configured
       signal (loss / confidence / entropy).
    2. Pick the threshold that maximises attack accuracy on those scores.
    3. Label each sample as member or non-member by comparing its score
       against that threshold.

    Parameters
    ----------
    member_loader:
        DataLoader for samples the model WAS trained on.
    nonmember_loader:
        DataLoader for samples the model was NOT trained on.

    Returns
    -------
    AttackResult
        The populated result; the same object is also kept on
        ``self.result``.
    """
    # Stage 1: per-sample signal for each group (kept for report plots).
    self.member_scores = self._compute_signal(member_loader)
    self.nonmember_scores = self._compute_signal(nonmember_loader)

    # Class labels, kept for evaluate_per_class().
    # NOTE(review): labels come from a second pass over each loader, so they
    # line up with the scores only if the loader yields samples in a
    # deterministic order (e.g. shuffle=False) — confirm for callers.
    self.member_labels = self._extract_labels(member_loader)
    self.nonmember_labels = self._extract_labels(nonmember_loader)

    n_members = len(self.member_scores)
    n_nonmembers = len(self.nonmember_scores)
    scores = np.concatenate([self.member_scores, self.nonmember_scores])
    # Members occupy the first n_members slots (label 1), non-members the rest (label 0).
    truth = np.concatenate([np.ones(n_members), np.zeros(n_nonmembers)])

    # Stage 2: accuracy-maximising decision threshold.
    self.threshold = self._find_optimal_threshold(scores, truth)

    # Stage 3: hard member / non-member decisions.
    predicted = self._apply_threshold(scores, self.threshold)

    # Re-express raw metric values so that a higher value always means
    # "more member-like", whichever metric was chosen.
    confidence = self._scores_to_confidence(scores)

    summary = {
        "metric": self.metric,
        "threshold": self.threshold,
        "member_mean": float(np.mean(self.member_scores)),
        "nonmember_mean": float(np.mean(self.nonmember_scores)),
        "member_std": float(np.std(self.member_scores)),
        "nonmember_std": float(np.std(self.nonmember_scores)),
    }
    self.result = AttackResult(
        predictions=predicted,
        ground_truth=truth,
        confidence_scores=confidence,
        attack_name=self.attack_name,
        metadata=summary,
    )
    return self.result

evaluate_per_class() -> dict[int, dict[str, float]]

Compute evaluation metrics separately for each class.

This reveals whether the attack works better on certain classes. For example, rare classes might be easier to identify as members because the model memorises them more.

Returns:

Type Description
dict[int, dict[str, float]]

Mapping from class label → metric dictionary. Each inner dict has the same keys as evaluate() (accuracy, precision, …).

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/mia_threshold.py
def evaluate_per_class(self) -> dict[int, dict[str, float]]:
    """Break the attack's evaluation metrics down **by class label**.

    Shows whether the attack is more effective on some classes than on
    others. For example, rare classes may be memorised more and therefore
    be easier to identify as members.

    Returns
    -------
    dict[int, dict[str, float]]
        Maps each class label to a metric dictionary with the same keys as
        ``evaluate()`` (accuracy, precision, …) plus ``n_samples``.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None:
        raise RuntimeError("Call run() before evaluate_per_class().")

    labels = np.concatenate([self.member_labels, self.nonmember_labels])
    outcome = self.result

    breakdown: dict[int, dict[str, float]] = {}
    for label in np.unique(labels):
        selected = labels == label
        count = int(selected.sum())
        guesses = outcome.predictions[selected]
        truth = outcome.ground_truth[selected]
        scores = outcome.confidence_scores[selected]

        # Ranking metrics need >= 2 samples drawn from both membership groups.
        has_both_groups = len(truth) >= 2 and len(np.unique(truth)) >= 2
        if has_both_groups:
            stats = self._compute_metrics(guesses, truth, scores)
            stats["n_samples"] = count
        else:
            # Degenerate class: report plain accuracy and zero everything else.
            acc = float(np.mean(guesses == truth)) if len(truth) > 0 else 0.0
            stats = {
                "accuracy": acc,
                "precision": 0.0,
                "recall": 0.0,
                "f1": 0.0,
                "auc_roc": 0.0,
                "auc_pr": 0.0,
                "tpr_at_1fpr": 0.0,
                "tpr_at_01fpr": 0.0,
                "n_samples": count,
            }
        breakdown[int(label)] = stats

    return breakdown

generate_report(output_dir: str | Path) -> Path

Generate a complete evaluation report with metrics and plots.

Creates the following files in output_dir:

  • metrics.json — overall evaluation metrics
  • per_class_metrics.json — per-class breakdown
  • roc_curve.png — ROC curve plot
  • score_distributions.png — histogram of member vs non-member scores
  • per_class_accuracy.png — bar chart of per-class attack accuracy
  • summary.txt — human-readable text summary

Parameters:

Name Type Description Default
output_dir str | Path

Directory where all report files are saved. Created if it doesn't exist.

required

Returns:

Type Description
Path

The output directory.

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/mia_threshold.py
def generate_report(self, output_dir: str | Path) -> Path:
    """Write the full evaluation report (metrics files, plots, summary).

    Files produced inside *output_dir*:

    - ``metrics.json`` — overall evaluation metrics
    - ``per_class_metrics.json`` — per-class breakdown
    - ``roc_curve.png`` — ROC curve plot
    - ``score_distributions.png`` — histogram of member vs non-member scores
    - ``per_class_accuracy.png`` — bar chart of per-class attack accuracy
    - ``summary.txt`` — human-readable text summary

    Parameters
    ----------
    output_dir:
        Destination directory for every report file; created (with any
        missing parents) when absent.

    Returns
    -------
    Path
        The output directory.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None:
        raise RuntimeError("Call run() before generate_report().")

    # Deferred import so matplotlib is only loaded when a report is requested.
    from auditml.attacks.visualization import (
        plot_per_class_metrics,
        plot_roc_curve,
        plot_score_distributions,
    )

    report_dir = Path(output_dir)
    report_dir.mkdir(parents=True, exist_ok=True)

    def dump_json(filename: str, payload) -> None:
        # Pretty-printed JSON file inside the report directory.
        with open(report_dir / filename, "w") as handle:
            json.dump(payload, handle, indent=2)

    overall = self.evaluate()
    dump_json("metrics.json", overall)

    by_class = self.evaluate_per_class()
    # JSON object keys must be strings, so stringify the integer class ids.
    dump_json(
        "per_class_metrics.json",
        {str(cls): stats for cls, stats in by_class.items()},
    )

    plot_roc_curve(
        ground_truth=self.result.ground_truth,
        confidence_scores=self.result.confidence_scores,
        save_path=report_dir / "roc_curve.png",
    )
    plot_score_distributions(
        member_scores=self.member_scores,
        nonmember_scores=self.nonmember_scores,
        metric_name=self.metric,
        threshold=self.threshold,
        save_path=report_dir / "score_distributions.png",
    )
    plot_per_class_metrics(
        per_class_metrics=by_class,
        save_path=report_dir / "per_class_accuracy.png",
    )

    self._write_summary(report_dir / "summary.txt", overall, by_class)

    return report_dir

Shadow Model MIA

auditml.attacks.mia_shadow.ShadowMIA

Bases: BaseAttack

Shadow-model Membership Inference Attack.

Workflow executed by run():

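  1. Train shadow models — each on a different random split of the same dataset distribution. The number of shadow models and their training epochs come from explicit keyword arguments, falling back to config.attack_params.mia_shadow and then to built-in defaults.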
  2. Collect attack data — for each shadow model, gather its softmax outputs on its own members (label 1) and non-members (label 0).
  3. Train attack MLP — a small binary classifier on the collected (probability_vector, membership_label) dataset.
  4. Attack the target — run the target model on the supplied member and non-member loaders, then classify each sample with the trained attack model.

Parameters:

Name Type Description Default
target_model Module

The trained model being audited.

required
config

Optional full AuditML configuration. When None, explicit keyword parameters (or built-in defaults) are used instead.

None
device device | str

Torch device.

'cpu'
shadow_dataset Dataset | None

The dataset from which shadow model training data is drawn. It should come from the same distribution as the target's training data (e.g. the full CIFAR-10 training set). If None, shadow models must be provided manually via shadow_models.

None
shadow_models list[tuple[Module, DataLoader, DataLoader]] | None

Pre-trained shadow models. If provided, skips the training step. Each entry is (model, member_loader, nonmember_loader).

None
Source code in src/auditml/attacks/mia_shadow.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
class ShadowMIA(BaseAttack):
    """Shadow-model Membership Inference Attack.

    Workflow executed by ``run()``:

    1. **Train shadow models** — each on a different random split of the
       same dataset distribution. The number of shadow models and their
       training epochs come from the explicit keyword arguments, falling
       back to ``config.attack_params.mia_shadow`` and then to built-in
       defaults.
    2. **Collect attack data** — for each shadow model, gather its softmax
       outputs on its own members (label 1) and non-members (label 0).
    3. **Train attack MLP** — a small binary classifier on the collected
       (probability_vector, membership_label) dataset.
    4. **Attack the target** — run the target model on the supplied member
       and non-member loaders, then classify each sample with the trained
       attack model.

    Parameters
    ----------
    target_model:
        The trained model being audited.
    config:
        Optional full AuditML configuration. When ``None``, the explicit
        keyword parameters below (or built-in defaults) are used.
    device:
        Torch device.
    shadow_dataset:
        The dataset from which shadow model training data is drawn.
        It should come from the **same distribution** as the target's
        training data (e.g. the full CIFAR-10 training set). If ``None``,
        shadow models must be provided manually via ``shadow_models``.
    shadow_models:
        Pre-trained shadow models. If provided, skips the training step.
        Each entry is ``(model, member_loader, nonmember_loader)``.
    shadow_model_fn:
        Optional zero-argument factory returning a fresh ``nn.Module`` for
        each shadow model. It takes priority over the config's model
        registry. When neither is available, a simple MLP is built from the
        shape of the first shadow sample.
    num_shadows, shadow_epochs, num_classes, batch_size, member_ratio, seed, optimizer, learning_rate, weight_decay:
        Keyword-only overrides. Any value other than ``None`` takes
        priority over the config, including falsy values such as
        ``seed=0`` or ``weight_decay=0.0``. ``None`` means "use the config
        value, else the default". The defaults are 4 shadows, 10 epochs,
        10 classes, batch size 64, member ratio 0.5, seed 42, ``"adam"``,
        learning rate 0.001 and weight decay 1e-4.
    """

    attack_name = "mia_shadow"

    def __init__(
        self,
        target_model: nn.Module,
        config=None,
        device: torch.device | str = "cpu",
        shadow_dataset: Dataset | None = None,
        shadow_models: list[tuple[nn.Module, DataLoader, DataLoader]] | None = None,
        shadow_model_fn=None,
        *,
        num_shadows: int | None = None,
        shadow_epochs: int | None = None,
        num_classes: int | None = None,
        batch_size: int | None = None,
        member_ratio: float | None = None,
        seed: int | None = None,
        optimizer: str | None = None,
        learning_rate: float | None = None,
        weight_decay: float | None = None,
    ) -> None:
        super().__init__(target_model, config, device)
        self.shadow_dataset = shadow_dataset
        self.shadow_models = shadow_models
        self.shadow_model_fn = shadow_model_fn

        def _pick(explicit, fallback):
            # ``is not None`` rather than ``or``: legitimate falsy values such as
            # seed=0 or weight_decay=0.0 must be honoured, not silently replaced
            # by the config / default value.
            return explicit if explicit is not None else fallback()

        # Explicit params take priority; fall back to config; then hardcoded defaults.
        has_cfg = config is not None
        cfg_p = config.attack_params.mia_shadow if has_cfg else None
        self.num_shadows = _pick(
            num_shadows, lambda: cfg_p.num_shadow_models if cfg_p else 4
        )
        self.shadow_epochs = _pick(
            shadow_epochs, lambda: cfg_p.shadow_epochs if cfg_p else 10
        )
        self.num_classes = _pick(
            num_classes, lambda: config.model.num_classes if has_cfg else 10
        )
        self._batch_size = _pick(
            batch_size, lambda: config.training.batch_size if has_cfg else 64
        )
        self._member_ratio = _pick(
            member_ratio, lambda: config.data.train_ratio if has_cfg else 0.5
        )
        self._seed = _pick(seed, lambda: config.training.seed if has_cfg else 42)
        self._optimizer = _pick(
            optimizer, lambda: config.training.optimizer if has_cfg else "adam"
        )
        self._lr = _pick(
            learning_rate, lambda: config.training.learning_rate if has_cfg else 0.001
        )
        self._weight_decay = _pick(
            weight_decay, lambda: config.training.weight_decay if has_cfg else 1e-4
        )

        # Will be populated during run()
        self.attack_model: AttackMLP | None = None
        self.trained_shadows: list[nn.Module] = []
        # Stored during run() for evaluation and visualization
        self.member_confidence: np.ndarray | None = None
        self.nonmember_confidence: np.ndarray | None = None
        self.member_labels: np.ndarray | None = None
        self.nonmember_labels: np.ndarray | None = None

    # ------------------------------------------------------------------
    # Main attack logic
    # ------------------------------------------------------------------

    def run(
        self,
        member_loader: DataLoader,
        nonmember_loader: DataLoader,
    ) -> AttackResult:
        """Execute the full shadow-model MIA pipeline.

        Steps:
            1. Train shadow models (or use pre-trained ones)
            2. Collect (output, membership) pairs from all shadows
            3. Train the attack MLP on shadow data
            4. Use the attack MLP to classify target model outputs

        Parameters
        ----------
        member_loader:
            DataLoader over samples the target model WAS trained on.
        nonmember_loader:
            DataLoader over samples the target model was NOT trained on.

        Returns
        -------
        AttackResult
            The populated result, also stored on ``self.result``.
        """
        # Step 1: Get shadow models with their data
        shadow_data = self._get_shadow_data()

        # Step 2: Collect attack training data from shadow models
        attack_features, attack_labels = self._collect_attack_data(shadow_data)
        logger.info(
            "Collected %d attack training samples (%d members, %d non-members)",
            len(attack_labels),
            int(attack_labels.sum()),
            int((1 - attack_labels).sum()),
        )

        # Step 3: Train the attack model
        self.attack_model = self._train_attack_model(attack_features, attack_labels)

        # Step 4: Attack the target model
        member_probs, _, member_true_labels = self.get_model_outputs(member_loader)
        nonmember_probs, _, nonmember_true_labels = self.get_model_outputs(nonmember_loader)

        # Store class labels for per-class evaluation
        self.member_labels = member_true_labels
        self.nonmember_labels = nonmember_true_labels

        # Build ground truth: 1 = member, 0 = non-member
        n_members = len(member_probs)
        ground_truth = np.concatenate([
            np.ones(n_members),
            np.zeros(len(nonmember_probs)),
        ])

        # Get attack model predictions
        all_probs = np.concatenate([member_probs, nonmember_probs])
        confidence_scores = self._attack_predict(all_probs)
        predictions = (confidence_scores >= 0.5).astype(np.int32)

        # Store per-group confidence for visualization (members come first)
        self.member_confidence = confidence_scores[:n_members]
        self.nonmember_confidence = confidence_scores[n_members:]

        self.result = AttackResult(
            predictions=predictions,
            ground_truth=ground_truth,
            confidence_scores=confidence_scores,
            attack_name=self.attack_name,
            metadata={
                "num_shadow_models": self.num_shadows,
                "shadow_epochs": self.shadow_epochs,
                "attack_train_samples": len(attack_labels),
                "member_mean_confidence": float(self.member_confidence.mean()),
                "nonmember_mean_confidence": float(self.nonmember_confidence.mean()),
            },
        )
        return self.result

    # ------------------------------------------------------------------
    # Step 1: Shadow model training
    # ------------------------------------------------------------------

    def _get_shadow_data(
        self,
    ) -> list[tuple[nn.Module, DataLoader, DataLoader]]:
        """Return shadow models with their member/non-member loaders.

        If ``shadow_models`` were passed at init, use them directly.
        Otherwise, train new shadow models from ``shadow_dataset``.

        Raises
        ------
        ValueError
            If neither ``shadow_models`` nor ``shadow_dataset`` was given.
        """
        if self.shadow_models is not None:
            return self.shadow_models

        if self.shadow_dataset is None:
            raise ValueError(
                "Either shadow_dataset or shadow_models must be provided. "
                "Pass the full training dataset as shadow_dataset so we can "
                "create independent splits for shadow model training."
            )

        return self._train_shadow_models()

    def _train_shadow_models(
        self,
    ) -> list[tuple[nn.Module, DataLoader, DataLoader]]:
        """Train shadow models from scratch.

        Each shadow model gets its own random member/non-member split of
        ``shadow_dataset``. This mirrors how the target model was trained,
        so the shadow models learn similar decision boundaries.

        Returns
        -------
        list of ``(shadow_model, member_eval_loader, nonmember_eval_loader)``
            The evaluation loaders use ``shuffle=False`` for consistent ordering.
        """
        batch_size = self._batch_size
        splits = get_shadow_data_splits(
            self.shadow_dataset,
            n_shadows=self.num_shadows,
            member_ratio=self._member_ratio,
            seed=self._seed,
        )

        # Start fresh so repeated run() calls don't accumulate stale shadows.
        self.trained_shadows = []
        results: list[tuple[nn.Module, DataLoader, DataLoader]] = []

        from tqdm import tqdm
        for i, (member_set, nonmember_set, _, _) in enumerate(
            tqdm(splits, desc="Training shadow models", unit="model", ncols=70)
        ):
            logger.info("Training shadow model %d/%d ...", i + 1, self.num_shadows)

            # Create a fresh model — user factory > config registry > MLP fallback
            if self.shadow_model_fn is not None:
                shadow = self.shadow_model_fn().to(self.device)
            elif self.config is not None:
                shadow = get_model(
                    arch=self.config.model.arch,
                    dataset=self.config.data.dataset.value,
                ).to(self.device)
            else:
                shadow = self._build_fallback_shadow(member_set).to(self.device)

            # Create data loaders
            train_loader = DataLoader(
                member_set, batch_size=batch_size, shuffle=True,
            )
            val_loader = DataLoader(
                nonmember_set, batch_size=batch_size, shuffle=False,
            )

            # Train the shadow model
            optimizer = build_optimizer(
                shadow,
                name=self._optimizer,
                lr=self._lr,
                weight_decay=self._weight_decay,
            )
            trainer = Trainer(
                model=shadow,
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer=optimizer,
                device=self.device,
            )
            trainer.train(epochs=self.shadow_epochs, patience=0)

            shadow.eval()
            self.trained_shadows.append(shadow)

            # Create evaluation loaders (no shuffle, for consistent ordering)
            member_eval = DataLoader(
                member_set, batch_size=batch_size, shuffle=False,
            )
            nonmember_eval = DataLoader(
                nonmember_set, batch_size=batch_size, shuffle=False,
            )
            results.append((shadow, member_eval, nonmember_eval))

        return results

    def _build_fallback_shadow(self, dataset) -> nn.Module:
        """Build a simple MLP shadow model when no factory or config is given.

        Infers the input dimension from the first sample in *dataset* and
        creates a 3-layer MLP that matches ``self.num_classes`` outputs.
        """
        sample_x = dataset[0][0]
        input_dim = int(sample_x.numel())
        num_out = self.num_classes

        # Hidden width tracks the input size, clamped to [64, 512].
        hidden = max(64, min(512, input_dim))

        class _FallbackMLP(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net = nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(input_dim, hidden),
                    nn.ReLU(),
                    nn.Linear(hidden, hidden // 2),
                    nn.ReLU(),
                    nn.Linear(hidden // 2, num_out),
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.net(x)

        # Separators between the layer sizes keep the log readable; without
        # them the four integers were printed glued together.
        logger.info(
            "shadow_model_fn not provided — using fallback MLP "
            "(%d → %d → %d → %d)",
            input_dim, hidden, hidden // 2, num_out,
        )
        return _FallbackMLP()

    # ------------------------------------------------------------------
    # Step 2: Collect attack training data
    # ------------------------------------------------------------------

    def _collect_attack_data(
        self,
        shadow_data: list[tuple[nn.Module, DataLoader, DataLoader]],
    ) -> tuple[np.ndarray, np.ndarray]:
        """Gather (softmax_output, membership_label) from all shadows.

        For each shadow model:
        - Run it on its member data → label these outputs as 1 (member)
        - Run it on its non-member data → label these outputs as 0

        Returns
        -------
        (features, labels)
            features: shape ``(total_samples, num_classes)``
            labels: shape ``(total_samples,)`` — 0 or 1

        Raises
        ------
        ValueError
            If *shadow_data* is empty.
        """
        if not shadow_data:
            raise ValueError(
                "No shadow models available: at least one "
                "(model, member_loader, nonmember_loader) entry is required."
            )

        all_features: list[np.ndarray] = []
        all_labels: list[np.ndarray] = []

        original_model = self.target_model
        for shadow_model, member_loader, nonmember_loader in shadow_data:
            # get_model_outputs() queries self.target_model, so point it at the
            # shadow temporarily. try/finally guarantees the real target is
            # restored even if output extraction raises.
            self.target_model = shadow_model
            try:
                self.target_model.eval()
                member_probs, _, _ = self.get_model_outputs(member_loader)
                nonmember_probs, _, _ = self.get_model_outputs(nonmember_loader)
            finally:
                self.target_model = original_model

            all_features.append(member_probs)
            all_features.append(nonmember_probs)
            all_labels.append(np.ones(len(member_probs)))
            all_labels.append(np.zeros(len(nonmember_probs)))

        return np.concatenate(all_features), np.concatenate(all_labels)

    # ------------------------------------------------------------------
    # Step 3: Train attack classifier
    # ------------------------------------------------------------------

    def _train_attack_model(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        epochs: int = 50,
        lr: float = 0.001,
    ) -> AttackMLP:
        """Train a binary MLP to predict membership from softmax outputs.

        Parameters
        ----------
        features:
            Shape ``(N, num_classes)`` — softmax probability vectors.
        labels:
            Shape ``(N,)`` — 1 for member, 0 for non-member.
        epochs:
            Training epochs for the attack model.
        lr:
            Learning rate.

        Returns
        -------
        AttackMLP
            The trained attack classifier (left in ``eval()`` mode).
        """
        input_dim = features.shape[1]
        model = AttackMLP(input_dim=input_dim).to(self.device)

        dataset = TensorDataset(
            torch.tensor(features, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.float32),
        )
        loader = DataLoader(dataset, batch_size=128, shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # The model emits raw logits; BCEWithLogitsLoss applies the sigmoid.
        criterion = nn.BCEWithLogitsLoss()

        # Initialised up front so the final log works even when epochs == 0.
        total_loss = 0.0
        model.train()
        for _ in range(epochs):
            total_loss = 0.0
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                optimizer.zero_grad()
                logits = model(batch_x).squeeze(-1)
                loss = criterion(logits, batch_y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        model.eval()
        logger.info("Attack model trained for %d epochs (final loss: %.4f)",
                     epochs, total_loss / max(len(loader), 1))
        return model

    # ------------------------------------------------------------------
    # Step 4: Predict with attack model
    # ------------------------------------------------------------------

    def _attack_predict(self, probs: np.ndarray) -> np.ndarray:
        """Use the trained attack MLP to predict membership probability.

        Parameters
        ----------
        probs:
            Shape ``(N, num_classes)`` — softmax outputs from the target.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — probability of being a member, in [0, 1].

        Raises
        ------
        RuntimeError
            If the attack model has not been trained yet.
        """
        if self.attack_model is None:
            raise RuntimeError("Attack model is not trained; call run() first.")
        self.attack_model.eval()
        x = torch.tensor(probs, dtype=torch.float32).to(self.device)
        with torch.no_grad():
            scores = self.attack_model.predict_proba(x)
        return scores.cpu().numpy()

    # ------------------------------------------------------------------
    # Per-class evaluation
    # ------------------------------------------------------------------

    def evaluate_per_class(self) -> dict[int, dict[str, float]]:
        """Compute evaluation metrics **separately for each class**.

        Groups all samples by their original class label and computes the
        full metric suite for each class. This reveals which classes are
        most vulnerable to the shadow model attack.

        Returns
        -------
        dict[int, dict[str, float]]
            Mapping from class label to metric dictionary (plus ``n_samples``).
            Classes with fewer than 2 samples, or with only one membership
            group, get plain accuracy and zeros for the other metrics.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before evaluate_per_class().")

        all_labels = np.concatenate([self.member_labels, self.nonmember_labels])
        unique_classes = np.unique(all_labels)

        per_class: dict[int, dict[str, float]] = {}
        for cls in unique_classes:
            mask = all_labels == cls
            preds_cls = self.result.predictions[mask]
            gt_cls = self.result.ground_truth[mask]
            scores_cls = self.result.confidence_scores[mask]

            if len(gt_cls) < 2 or len(np.unique(gt_cls)) < 2:
                per_class[int(cls)] = {
                    "accuracy": float(np.mean(preds_cls == gt_cls)) if len(gt_cls) > 0 else 0.0,
                    "precision": 0.0,
                    "recall": 0.0,
                    "f1": 0.0,
                    "auc_roc": 0.0,
                    "auc_pr": 0.0,
                    "tpr_at_1fpr": 0.0,
                    "tpr_at_01fpr": 0.0,
                    "n_samples": int(mask.sum()),
                }
                continue

            metrics = self._compute_metrics(preds_cls, gt_cls, scores_cls)
            metrics["n_samples"] = int(mask.sum())
            per_class[int(cls)] = metrics

        return per_class

    # ------------------------------------------------------------------
    # Report generation
    # ------------------------------------------------------------------

    def generate_report(self, output_dir: str | Path) -> Path:
        """Generate a complete evaluation report with metrics and plots.

        Creates the following files in *output_dir*:

        - ``metrics.json`` — overall evaluation metrics
        - ``per_class_metrics.json`` — per-class breakdown
        - ``roc_curve.png`` — ROC curve plot
        - ``confidence_distributions.png`` — histogram of attack confidence
        - ``per_class_accuracy.png`` — bar chart of per-class accuracy
        - ``summary.txt`` — human-readable text summary

        Parameters
        ----------
        output_dir:
            Directory where all report files are saved. Created if it
            doesn't exist.

        Returns
        -------
        Path
            The output directory.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before generate_report().")

        # Lazy import to avoid matplotlib overhead when not needed
        from auditml.attacks.visualization import (
            plot_per_class_metrics,
            plot_roc_curve,
            plot_score_distributions,
        )

        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # 1. Overall metrics
        metrics = self.evaluate()
        with open(out / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        # 2. Per-class metrics (JSON keys must be strings)
        per_class = self.evaluate_per_class()
        per_class_str = {str(k): v for k, v in per_class.items()}
        with open(out / "per_class_metrics.json", "w") as f:
            json.dump(per_class_str, f, indent=2)

        # 3. ROC curve
        plot_roc_curve(
            ground_truth=self.result.ground_truth,
            confidence_scores=self.result.confidence_scores,
            title="ROC Curve — Shadow Model MIA",
            save_path=out / "roc_curve.png",
        )

        # 4. Confidence distribution histogram
        plot_score_distributions(
            member_scores=self.member_confidence,
            nonmember_scores=self.nonmember_confidence,
            metric_name="attack confidence",
            save_path=out / "confidence_distributions.png",
            title="Attack Confidence Distribution — Shadow Model MIA",
        )

        # 5. Per-class accuracy bar chart
        plot_per_class_metrics(
            per_class_metrics=per_class,
            save_path=out / "per_class_accuracy.png",
        )

        # 6. Summary text
        self._write_summary(out / "summary.txt", metrics, per_class)

        return out

    def _write_summary(
        self,
        path: Path,
        metrics: dict[str, float],
        per_class: dict[int, dict[str, float]],
    ) -> None:
        """Write a human-readable text summary of the attack results.

        Sections: configuration and sample counts, overall metrics,
        per-class accuracy/AUC, and the result metadata.
        """
        lines = [
            "=" * 60,
            "AuditML — Shadow Model MIA Report",
            "=" * 60,
            "",
            f"Shadow models:   {self.num_shadows}",
            f"Shadow epochs:   {self.shadow_epochs}",
            f"Total samples:   {len(self.result.predictions)}",
            f"  Members:       {int(self.result.ground_truth.sum())}",
            f"  Non-members:   {int((1 - self.result.ground_truth).sum())}",
            "",
            "--- Overall Metrics ---",
        ]
        for key, val in metrics.items():
            lines.append(f"  {key:<20s}: {val:.4f}")

        lines.append("")
        lines.append("--- Per-Class Breakdown ---")
        for cls in sorted(per_class.keys()):
            m = per_class[cls]
            lines.append(
                f"  Class {cls:>3d}:  acc={m['accuracy']:.3f}  "
                f"auc={m['auc_roc']:.3f}  n={m['n_samples']}"
            )

        lines.append("")
        lines.append("--- Metadata ---")
        for key, val in self.result.metadata.items():
            lines.append(f"  {key}: {val}")

        lines.append("")
        path.write_text("\n".join(lines))

run(member_loader: DataLoader, nonmember_loader: DataLoader) -> AttackResult

Execute the full shadow-model MIA pipeline.

Steps: (1) train shadow models (or use pre-trained ones); (2) collect (output, membership) pairs from all shadows; (3) train the attack MLP on the shadow data; (4) use the attack MLP to classify the target model's outputs.

Source code in src/auditml/attacks/mia_shadow.py
def run(
    self,
    member_loader: DataLoader,
    nonmember_loader: DataLoader,
) -> AttackResult:
    """Run the shadow-model membership-inference pipeline end to end.

    Pipeline:
        1. Obtain shadow models (trained here or pre-trained) with their data.
        2. Harvest ``(output, membership)`` pairs from every shadow model.
        3. Fit the attack MLP on those pairs.
        4. Score the target model's outputs with the fitted attack MLP.

    Besides returning the ``AttackResult`` (also stored on ``self.result``),
    this sets ``attack_model``, ``member_labels``, ``nonmember_labels``,
    ``member_confidence`` and ``nonmember_confidence`` on ``self``.
    """
    # Steps 1-2: shadow models -> supervised training set for the attacker.
    attack_features, attack_labels = self._collect_attack_data(self._get_shadow_data())
    n_attack = len(attack_labels)
    logger.info(
        "Collected %d attack training samples (%d members, %d non-members)",
        n_attack,
        int(attack_labels.sum()),
        int((1 - attack_labels).sum()),
    )

    # Step 3: fit the attack classifier.
    self.attack_model = self._train_attack_model(attack_features, attack_labels)

    # Step 4: query the target model on both populations.
    in_probs, _, in_labels = self.get_model_outputs(member_loader)
    out_probs, _, out_labels = self.get_model_outputs(nonmember_loader)
    self.member_labels, self.nonmember_labels = in_labels, out_labels

    n_in, n_out = len(in_probs), len(out_probs)
    truth = np.concatenate([np.ones(n_in), np.zeros(n_out)])  # 1 = member

    scores = self._attack_predict(np.concatenate([in_probs, out_probs]))
    hard_preds = (scores >= 0.5).astype(np.int32)

    # Split the scores back into the two populations (kept for plotting).
    self.member_confidence, self.nonmember_confidence = scores[:n_in], scores[n_in:]

    self.result = AttackResult(
        predictions=hard_preds,
        ground_truth=truth,
        confidence_scores=scores,
        attack_name=self.attack_name,
        metadata={
            "num_shadow_models": self.num_shadows,
            "shadow_epochs": self.shadow_epochs,
            "attack_train_samples": n_attack,
            "member_mean_confidence": float(self.member_confidence.mean()),
            "nonmember_mean_confidence": float(self.nonmember_confidence.mean()),
        },
    )
    return self.result

evaluate_per_class() -> dict[int, dict[str, float]]

Compute evaluation metrics separately for each class.

Groups all samples by their original class label and computes the full metric suite for each class. This reveals which classes are most vulnerable to the shadow model attack.

Returns:

Type Description
dict[int, dict[str, float]]

Mapping from class label to metric dictionary.

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/mia_shadow.py
def evaluate_per_class(self) -> dict[int, dict[str, float]]:
    """Break the attack's metrics down by the samples' original class label.

    Every sample (member or non-member) is bucketed by its true class and
    the full metric suite is computed per bucket, exposing which classes
    leak membership most under the shadow-model attack. Buckets with fewer
    than two samples, or containing only one membership label, cannot
    support AUC-style metrics; they get accuracy only and ``0.0`` elsewhere.

    Returns
    -------
    dict[int, dict[str, float]]
        Class label -> metric dictionary (always includes ``n_samples``).

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    res = self.result
    if res is None:
        raise RuntimeError("Call run() before evaluate_per_class().")

    labels = np.concatenate([self.member_labels, self.nonmember_labels])

    breakdown: dict[int, dict[str, float]] = {}
    for label in np.unique(labels):
        in_class = labels == label
        count = int(in_class.sum())
        preds = res.predictions[in_class]
        truth = res.ground_truth[in_class]
        scores = res.confidence_scores[in_class]

        degenerate = len(truth) < 2 or len(np.unique(truth)) < 2
        if not degenerate:
            stats = self._compute_metrics(preds, truth, scores)
            stats["n_samples"] = count
        else:
            acc = float(np.mean(preds == truth)) if len(truth) > 0 else 0.0
            stats = {"accuracy": acc}
            stats.update(
                dict.fromkeys(
                    ("precision", "recall", "f1", "auc_roc", "auc_pr",
                     "tpr_at_1fpr", "tpr_at_01fpr"),
                    0.0,
                )
            )
            stats["n_samples"] = count
        breakdown[int(label)] = stats

    return breakdown

generate_report(output_dir: str | Path) -> Path

Generate a complete evaluation report with metrics and plots.

Creates the following files in output_dir:

  • metrics.json — overall evaluation metrics
  • per_class_metrics.json — per-class breakdown
  • roc_curve.png — ROC curve plot
  • confidence_distributions.png — histogram of attack confidence
  • per_class_accuracy.png — bar chart of per-class accuracy
  • summary.txt — human-readable text summary

Parameters:

Name Type Description Default
output_dir str | Path

Directory where all report files are saved.

required

Returns:

Type Description
Path

The output directory.

Source code in src/auditml/attacks/mia_shadow.py
def generate_report(self, output_dir: str | Path) -> Path:
    """Generate a complete evaluation report with metrics and plots.

    Creates the following files in *output_dir*:

    - ``metrics.json`` — overall evaluation metrics
    - ``per_class_metrics.json`` — per-class breakdown
    - ``roc_curve.png`` — ROC curve plot
    - ``confidence_distributions.png`` — histogram of attack confidence
    - ``per_class_accuracy.png`` — bar chart of per-class accuracy
    - ``summary.txt`` — human-readable text summary

    Parameters
    ----------
    output_dir:
        Directory where all report files are saved.

    Returns
    -------
    Path
        The output directory.
    """
    if self.result is None:
        raise RuntimeError("Call run() before generate_report().")

    from auditml.attacks.visualization import (
        plot_per_class_metrics,
        plot_roc_curve,
        plot_score_distributions,
    )

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # 1. Overall metrics
    metrics = self.evaluate()
    with open(out / "metrics.json", "w") as f:
        json.dump(metrics, f, indent=2)

    # 2. Per-class metrics
    per_class = self.evaluate_per_class()
    per_class_str = {str(k): v for k, v in per_class.items()}
    with open(out / "per_class_metrics.json", "w") as f:
        json.dump(per_class_str, f, indent=2)

    # 3. ROC curve
    plot_roc_curve(
        ground_truth=self.result.ground_truth,
        confidence_scores=self.result.confidence_scores,
        title="ROC Curve — Shadow Model MIA",
        save_path=out / "roc_curve.png",
    )

    # 4. Confidence distribution histogram
    plot_score_distributions(
        member_scores=self.member_confidence,
        nonmember_scores=self.nonmember_confidence,
        metric_name="attack confidence",
        save_path=out / "confidence_distributions.png",
        title="Attack Confidence Distribution — Shadow Model MIA",
    )

    # 5. Per-class accuracy bar chart
    plot_per_class_metrics(
        per_class_metrics=per_class,
        save_path=out / "per_class_accuracy.png",
    )

    # 6. Summary text
    self._write_summary(out / "summary.txt", metrics, per_class)

    return out

Model Inversion

auditml.attacks.model_inversion.ModelInversion

Bases: BaseAttack

Gradient-based Model Inversion attack.

For each target class, optimises a synthetic image so that the model classifies it with maximum confidence. The reconstructed images reveal what the model has learned — and potentially memorised — about each class.

Parameters:

Name Type Description Default
target_model Module

The trained model to attack (white-box — needs gradients).

required
config

Full AuditML configuration.

None
device device | str

Torch device.

'cpu'
input_shape tuple[int, ...] | None

Shape of the model's input, e.g. (1, 28, 28) for MNIST or (3, 32, 32) for CIFAR. If None, it is inferred from the dataset name in config; when no config is given either, it is auto-detected from the first member batch in run().

None
Source code in src/auditml/attacks/model_inversion.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
class ModelInversion(BaseAttack):
    """Gradient-based Model Inversion attack.

    For each target class, optimises a synthetic image so that the model
    classifies it with maximum confidence.  The reconstructed images
    reveal what the model has learned — and potentially memorised — about
    each class.

    Parameters
    ----------
    target_model:
        The trained model to attack (white-box — needs gradients).
    config:
        Optional AuditML configuration. Supplies the value of any
        hyper-parameter that is not passed explicitly.
    device:
        Torch device.
    input_shape:
        Shape of the model's input, e.g. ``(1, 28, 28)`` for MNIST or
        ``(3, 32, 32)`` for CIFAR. If ``None``, inferred from the
        dataset name in config; without a config it is auto-detected
        from the first member batch in ``run()``.
    num_iterations, learning_rate, lambda_tv, lambda_l2, target_class, num_classes:
        Keyword-only overrides. Each one falls back to the config value
        and then to a hard-coded default (500, 0.1, 0.001, 0.0, ``None``
        meaning "all classes", and 10 respectively).
    """

    attack_name = "model_inversion"

    def __init__(
        self,
        target_model: nn.Module,
        config=None,
        device: torch.device | str = "cpu",
        input_shape: tuple[int, ...] | None = None,
        *,
        num_iterations: int | None = None,
        learning_rate: float | None = None,
        lambda_tv: float | None = None,
        lambda_l2: float | None = None,
        target_class: int | None = None,
        num_classes: int | None = None,
    ) -> None:
        super().__init__(target_model, config, device)

        # Explicit params take priority; fall back to config; then hardcoded defaults.
        # ``is not None`` is used instead of ``or`` so that legitimate falsy
        # values are honoured — in particular ``target_class=0`` must NOT
        # silently turn into "invert every class".
        cfg_p = config.attack_params.model_inversion if config is not None else None

        def _resolve(explicit, cfg_name: str, default):
            if explicit is not None:
                return explicit
            if cfg_p is not None:
                return getattr(cfg_p, cfg_name)
            return default

        self.num_iterations = _resolve(num_iterations, "num_iterations", 500)
        self.lr = _resolve(learning_rate, "learning_rate", 0.1)
        self.lambda_tv = _resolve(lambda_tv, "lambda_tv", 0.001)
        self.lambda_l2 = _resolve(lambda_l2, "lambda_l2", 0.0)
        self.target_class = _resolve(target_class, "target_class", None)
        if num_classes is not None:
            self.num_classes = num_classes
        elif config is not None:
            self.num_classes = config.model.num_classes
        else:
            self.num_classes = 10

        # Determine input shape — explicit > config dataset > auto-detect in run()
        if input_shape is not None:
            self.input_shape = input_shape
        elif config is not None:
            dataset_name = config.data.dataset.value
            if dataset_name in DATASET_INFO:
                self.input_shape = DATASET_INFO[dataset_name].input_shape
            else:
                raise ValueError(
                    f"Cannot infer input_shape for dataset {dataset_name!r}. "
                    "Pass input_shape explicitly."
                )
        else:
            # Will be auto-detected from the first batch in run()
            self.input_shape = None

        # Populated during run()
        self.reconstructions: dict[int, np.ndarray] = {}
        self.reconstruction_confidences: dict[int, float] = {}
        # Stored during run() for visualization
        self.member_scores: np.ndarray | None = None
        self.nonmember_scores: np.ndarray | None = None

    # ------------------------------------------------------------------
    # Main attack logic
    # ------------------------------------------------------------------

    def run(
        self,
        member_loader: DataLoader,
        nonmember_loader: DataLoader,
    ) -> AttackResult:
        """Execute the model inversion attack.

        For each target class, reconstruct an image and measure how
        confidently the model classifies it. Then use the member and
        non-member loaders to evaluate: does the model assign higher
        confidence to reconstructions of classes it trained on?

        The ``member_loader`` and ``nonmember_loader`` are used to
        compute a membership-like signal: for each sample, we measure
        the similarity between the model's output on that sample and
        the reconstructed class prototype. Members tend to produce
        outputs closer to the reconstruction.

        Each call starts from fresh reconstruction dictionaries, so
        results from a previous ``run()`` never leak into this one.
        """
        # Auto-detect input shape from data if not provided at init
        if self.input_shape is None:
            first_batch = next(iter(member_loader))
            self.input_shape = tuple(first_batch[0].shape[1:])
            logger.info("Auto-detected input_shape=%s from data", self.input_shape)

        # Determine which classes to invert
        if self.target_class is not None:
            classes_to_invert = [self.target_class]
        else:
            classes_to_invert = list(range(self.num_classes))

        # Rebind (rather than .clear()) so a previously returned AttackResult,
        # whose metadata references the old confidences dict, stays intact.
        self.reconstructions = {}
        self.reconstruction_confidences = {}

        # Step 1: Reconstruct images for each target class
        from tqdm import tqdm as _tqdm
        total = len(classes_to_invert)
        progress = _tqdm(classes_to_invert, desc="Inverting classes", unit="class", ncols=70)
        for position, cls in enumerate(progress, start=1):
            # Report the class label and the loop position separately; the old
            # "cls+1/total" form printed e.g. "6/1" for a single target class.
            logger.info("Inverting class %d (%d/%d) ...", cls, position, total)
            recon, confidence = self.invert_class(cls)
            self.reconstructions[cls] = recon.detach().cpu().numpy()
            self.reconstruction_confidences[cls] = confidence

        # Step 2: Compute membership signal using reconstruction similarity
        self.member_scores = self._compute_similarity_scores(member_loader)
        self.nonmember_scores = self._compute_similarity_scores(nonmember_loader)
        member_scores = self.member_scores
        nonmember_scores = self.nonmember_scores

        # Build ground truth (1 = member, 0 = non-member) and combined scores
        ground_truth = np.concatenate([
            np.ones(len(member_scores)),
            np.zeros(len(nonmember_scores)),
        ])
        all_scores = np.concatenate([member_scores, nonmember_scores])

        # Threshold at median for binary predictions
        threshold = float(np.median(all_scores))
        predictions = (all_scores >= threshold).astype(np.int32)

        self.result = AttackResult(
            predictions=predictions,
            ground_truth=ground_truth,
            confidence_scores=all_scores,
            attack_name=self.attack_name,
            metadata={
                "num_classes_inverted": len(classes_to_invert),
                "num_iterations": self.num_iterations,
                "lambda_tv": self.lambda_tv,
                "lambda_l2": self.lambda_l2,
                "reconstruction_confidences": self.reconstruction_confidences,
                "mean_member_similarity": float(member_scores.mean()),
                "mean_nonmember_similarity": float(nonmember_scores.mean()),
            },
        )
        return self.result

    # ------------------------------------------------------------------
    # Core inversion — reconstruct one class
    # ------------------------------------------------------------------

    def invert_class(
        self,
        target_class: int,
        num_iterations: int | None = None,
    ) -> tuple[torch.Tensor, float]:
        """Reconstruct an image for a single target class.

        Starting from Gaussian noise, the input is optimised with Adam to
        minimise ``-log P(target_class)`` plus optional TV / L2 penalties.
        The returned image is exactly the input at which the highest
        target-class probability was observed.

        Parameters
        ----------
        target_class:
            The class label to reconstruct.
        num_iterations:
            Override the config value. Uses ``self.num_iterations`` if None.

        Returns
        -------
        (reconstructed_image, confidence)
            - reconstructed_image: tensor of shape ``(1, *input_shape)``
            - confidence: model's softmax probability for target_class
              evaluated on that exact image
        """
        if num_iterations is None:
            num_iterations = self.num_iterations

        # Ensure model is in eval mode but gradients can flow through
        self.target_model.eval()

        # Start from random noise, requires_grad so we can optimise it
        x = torch.randn(1, *self.input_shape, device=self.device, requires_grad=True)

        optimizer = torch.optim.Adam([x], lr=self.lr)

        best_confidence = 0.0
        best_x = x.detach().clone()

        from tqdm import tqdm
        pbar = tqdm(range(num_iterations), desc=f"  class {target_class}", leave=False,
                    unit="step", ncols=70)
        for _ in pbar:
            # Forward pass
            logits = self.target_model(x)
            probs = F.softmax(logits, dim=1)

            # Track best reconstruction BEFORE the optimiser step, so the
            # stored image is the very input that produced this confidence.
            current_confidence = probs[0, target_class].item()
            if current_confidence > best_confidence:
                best_confidence = current_confidence
                best_x = x.detach().clone()

            # Classification loss: maximise P(target_class)
            # Equivalent to minimising -log(P(target_class))
            cls_loss = -torch.log(probs[0, target_class] + 1e-10)

            # Regularisation
            reg_loss = torch.tensor(0.0, device=self.device)
            if self.lambda_tv > 0:
                reg_loss = reg_loss + self.lambda_tv * self._total_variation(x)
            if self.lambda_l2 > 0:
                reg_loss = reg_loss + self.lambda_l2 * torch.norm(x)

            total_loss = cls_loss + reg_loss
            # Differentiate w.r.t. the input only: ``backward()`` would also
            # accumulate gradients into the target model's parameters on
            # every step, leaving the attacked model with polluted ``.grad``.
            (grad_x,) = torch.autograd.grad(total_loss, x)
            x.grad = grad_x
            optimizer.step()

            pbar.set_postfix(conf=f"{current_confidence:.3f}")

        logger.info(
            "Class %d: confidence=%.4f after %d iterations",
            target_class, best_confidence, num_iterations,
        )
        return best_x, best_confidence

    # ------------------------------------------------------------------
    # Membership signal via reconstruction similarity
    # ------------------------------------------------------------------

    @torch.no_grad()
    def _compute_similarity_scores(self, loader: DataLoader) -> np.ndarray:
        """Compute how similar each sample's output is to its class reconstruction.

        For each sample, we measure the cosine similarity between the
        model's softmax output on that sample and the softmax output
        on the reconstructed image for that sample's class. Samples whose
        class was not inverted get a neutral score of 0.5.

        Parameters
        ----------
        loader:
            DataLoader yielding ``(inputs, targets)`` batches to score.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — similarity score per sample, in loader order.
        """
        self.target_model.eval()
        all_scores: list[float] = []

        # Pre-compute reconstruction output vectors for each class
        recon_outputs: dict[int, np.ndarray] = {}
        for cls, recon_img in self.reconstructions.items():
            recon_tensor = torch.as_tensor(recon_img, dtype=torch.float32).to(self.device)
            logits = self.target_model(recon_tensor)
            recon_outputs[cls] = F.softmax(logits, dim=1).cpu().numpy()[0]

        for inputs, targets in loader:
            inputs = inputs.to(self.device)
            logits = self.target_model(inputs)
            probs = F.softmax(logits, dim=1).cpu().numpy()

            for sample_vec, target in zip(probs, targets):
                recon_vec = recon_outputs.get(int(target))
                if recon_vec is None:
                    # Class wasn't inverted — use neutral score
                    all_scores.append(0.5)
                    continue
                # Cosine similarity between sample output and reconstruction output
                cos_sim = float(
                    np.dot(sample_vec, recon_vec)
                    / (np.linalg.norm(sample_vec) * np.linalg.norm(recon_vec) + 1e-10)
                )
                all_scores.append(cos_sim)

        return np.array(all_scores)

    # ------------------------------------------------------------------
    # Regularisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _total_variation(x: torch.Tensor) -> torch.Tensor:
        """Compute Total Variation loss for an image tensor.

        TV loss encourages spatial smoothness by penalising large
        differences between neighbouring pixels. This prevents the
        optimisation from producing noisy, unrealistic images.

        Parameters
        ----------
        x:
            Image tensor of shape ``(B, C, H, W)``.

        Returns
        -------
        torch.Tensor
            Scalar TV loss (mean squared vertical + horizontal differences).
        """
        diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]  # vertical differences
        diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]  # horizontal differences
        return torch.mean(diff_h ** 2) + torch.mean(diff_w ** 2)

    # ------------------------------------------------------------------
    # Report generation
    # ------------------------------------------------------------------

    def generate_report(self, output_dir: str | Path) -> Path:
        """Generate a complete model inversion report.

        Creates the following files in *output_dir*:

        - ``metrics.json`` — overall evaluation metrics
        - ``reconstructions.png`` — grid of reconstructed images
        - ``reconstruction_confidence.png`` — per-class confidence bar chart
        - ``similarity_distributions.png`` — member vs non-member similarity
        - ``roc_curve.png`` — ROC curve
        - ``summary.txt`` — human-readable text summary

        Parameters
        ----------
        output_dir:
            Directory where all report files are saved.

        Returns
        -------
        Path
            The output directory.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before generate_report().")

        from auditml.attacks.visualization import (
            plot_reconstruction_confidence,
            plot_reconstructions,
            plot_roc_curve,
            plot_score_distributions,
        )

        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # 1. Overall metrics + SSIM reconstruction quality
        metrics = self.evaluate()

        # SSIM of each class reconstruction against the mean reconstruction
        # (only meaningful when at least two classes were inverted).
        recon_list = list(self.reconstructions.values())
        if len(recon_list) >= 2:
            flat = [r.flatten() for r in recon_list]
            mean_recon = np.mean(np.stack(flat), axis=0)
            ssim_vs_mean = _batch_ssim(flat, [mean_recon] * len(flat))
            metrics["mean_ssim_vs_mean_recon"] = float(np.mean(ssim_vs_mean))
            metrics["reconstruction_ssim_scores"] = {
                int(cls): float(s)
                for cls, s in zip(self.reconstructions.keys(), ssim_vs_mean)
            }

        with open(out / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        # 2. Reconstructed images grid
        plot_reconstructions(
            reconstructions=self.reconstructions,
            confidences=self.reconstruction_confidences,
            save_path=out / "reconstructions.png",
        )

        # 3. Reconstruction confidence bar chart
        plot_reconstruction_confidence(
            confidences=self.reconstruction_confidences,
            save_path=out / "reconstruction_confidence.png",
        )

        # 4. Similarity distribution histogram
        plot_score_distributions(
            member_scores=self.member_scores,
            nonmember_scores=self.nonmember_scores,
            metric_name="cosine similarity",
            save_path=out / "similarity_distributions.png",
            title="Similarity Distribution — Model Inversion",
        )

        # 5. ROC curve
        plot_roc_curve(
            ground_truth=self.result.ground_truth,
            confidence_scores=self.result.confidence_scores,
            title="ROC Curve — Model Inversion",
            save_path=out / "roc_curve.png",
        )

        # 6. Summary text
        self._write_summary(out / "summary.txt", metrics)

        return out

    def _write_summary(
        self,
        path: Path,
        metrics: dict[str, float],
    ) -> None:
        """Write a human-readable text summary.

        Dict-valued metrics are printed via ``str()``; all others with four
        decimals. The file is written as UTF-8 because the header contains
        non-ASCII characters (``—``).
        """
        lines = [
            "=" * 60,
            "AuditML — Model Inversion Report",
            "=" * 60,
            "",
            f"Classes inverted: {len(self.reconstructions)}",
            f"Iterations:       {self.num_iterations}",
            f"Learning rate:    {self.lr}",
            f"Lambda TV:        {self.lambda_tv}",
            f"Lambda L2:        {self.lambda_l2}",
            f"Input shape:      {self.input_shape}",
            f"Total samples:    {len(self.result.predictions)}",
            f"  Members:        {int(self.result.ground_truth.sum())}",
            f"  Non-members:    {int((1 - self.result.ground_truth).sum())}",
            "",
            "--- Reconstruction Confidences ---",
        ]
        for cls in sorted(self.reconstruction_confidences.keys()):
            lines.append(f"  Class {cls:>3d}: {self.reconstruction_confidences[cls]:.4f}")

        lines.append("")
        lines.append("--- Overall Metrics ---")
        for key, val in metrics.items():
            if isinstance(val, dict):
                lines.append(f"  {key:<20s}: {val}")
            else:
                lines.append(f"  {key:<20s}: {val:.4f}")

        lines.append("")
        lines.append("--- Metadata ---")
        for key, val in self.result.metadata.items():
            # Already listed in its own section above.
            if key != "reconstruction_confidences":
                lines.append(f"  {key}: {val}")

        lines.append("")
        path.write_text("\n".join(lines), encoding="utf-8")

run(member_loader: DataLoader, nonmember_loader: DataLoader) -> AttackResult

Execute the model inversion attack.

For each target class, reconstruct an image and measure how confidently the model classifies it. Then use the member and non-member loaders to evaluate: does the model assign higher confidence to reconstructions of classes it trained on?

The member_loader and nonmember_loader are used to compute a membership-like signal: for each sample, we measure the similarity between the model's output on that sample and the reconstructed class prototype. Members tend to produce outputs closer to the reconstruction.

Source code in src/auditml/attacks/model_inversion.py
def run(
    self,
    member_loader: DataLoader,
    nonmember_loader: DataLoader,
) -> AttackResult:
    """Execute the model inversion attack.

    For each target class, reconstruct an image and measure how
    confidently the model classifies it. Then use the member and
    non-member loaders to evaluate: does the model assign higher
    confidence to reconstructions of classes it trained on?

    The ``member_loader`` and ``nonmember_loader`` are used to
    compute a membership-like signal: for each sample, we measure
    the similarity between the model's output on that sample and
    the reconstructed class prototype. Members tend to produce
    outputs closer to the reconstruction.

    Each call starts from fresh reconstruction dictionaries, so
    results from a previous ``run()`` never leak into this one.
    """
    # Auto-detect input shape from data if not provided at init
    if self.input_shape is None:
        first_batch = next(iter(member_loader))
        self.input_shape = tuple(first_batch[0].shape[1:])
        logger.info("Auto-detected input_shape=%s from data", self.input_shape)

    # Determine which classes to invert
    if self.target_class is not None:
        classes_to_invert = [self.target_class]
    else:
        classes_to_invert = list(range(self.num_classes))

    # Rebind (rather than .clear()) so a previously returned AttackResult,
    # whose metadata references the old confidences dict, stays intact.
    self.reconstructions = {}
    self.reconstruction_confidences = {}

    # Step 1: Reconstruct images for each target class
    from tqdm import tqdm as _tqdm
    total = len(classes_to_invert)
    progress = _tqdm(classes_to_invert, desc="Inverting classes", unit="class", ncols=70)
    for position, cls in enumerate(progress, start=1):
        # Report the class label and the loop position separately; the old
        # "cls+1/total" form printed e.g. "6/1" for a single target class.
        logger.info("Inverting class %d (%d/%d) ...", cls, position, total)
        recon, confidence = self.invert_class(cls)
        self.reconstructions[cls] = recon.detach().cpu().numpy()
        self.reconstruction_confidences[cls] = confidence

    # Step 2: Compute membership signal using reconstruction similarity
    self.member_scores = self._compute_similarity_scores(member_loader)
    self.nonmember_scores = self._compute_similarity_scores(nonmember_loader)
    member_scores = self.member_scores
    nonmember_scores = self.nonmember_scores

    # Build ground truth (1 = member, 0 = non-member) and combined scores
    ground_truth = np.concatenate([
        np.ones(len(member_scores)),
        np.zeros(len(nonmember_scores)),
    ])
    all_scores = np.concatenate([member_scores, nonmember_scores])

    # Threshold at median for binary predictions
    threshold = float(np.median(all_scores))
    predictions = (all_scores >= threshold).astype(np.int32)

    self.result = AttackResult(
        predictions=predictions,
        ground_truth=ground_truth,
        confidence_scores=all_scores,
        attack_name=self.attack_name,
        metadata={
            "num_classes_inverted": len(classes_to_invert),
            "num_iterations": self.num_iterations,
            "lambda_tv": self.lambda_tv,
            "lambda_l2": self.lambda_l2,
            "reconstruction_confidences": self.reconstruction_confidences,
            "mean_member_similarity": float(member_scores.mean()),
            "mean_nonmember_similarity": float(nonmember_scores.mean()),
        },
    )
    return self.result

invert_class(target_class: int, num_iterations: int | None = None) -> tuple[torch.Tensor, float]

Reconstruct an image for a single target class.

Parameters:

Name Type Description Default
target_class int

The class label to reconstruct.

required
num_iterations int | None

Override the config value. Uses self.num_iterations if None.

None

Returns:

Type Description
(reconstructed_image, confidence)
  • reconstructed_image: tensor of shape (1, C, H, W)
  • confidence: model's softmax probability for target_class
Source code in src/auditml/attacks/model_inversion.py
def invert_class(
    self,
    target_class: int,
    num_iterations: int | None = None,
) -> tuple[torch.Tensor, float]:
    """Reconstruct an image for a single target class.

    Starting from Gaussian noise, the input is optimised with Adam to
    maximise the target model's probability for ``target_class``, with
    optional total-variation (``self.lambda_tv``) and L2
    (``self.lambda_l2``) regularisation.  The returned tensor is the
    iterate with the highest observed confidence.

    Parameters
    ----------
    target_class:
        The class label to reconstruct.
    num_iterations:
        Override the config value. Uses ``self.num_iterations`` if None.

    Returns
    -------
    (reconstructed_image, confidence)
        - reconstructed_image: detached tensor of shape
          ``(1, *self.input_shape)``
        - confidence: the model's softmax probability for
          ``target_class`` evaluated on exactly that returned tensor
    """
    if num_iterations is None:
        num_iterations = self.num_iterations

    # Eval mode (dropout off, BN uses running stats); gradients still flow to x.
    self.target_model.eval()

    # Only the input is optimised.  Freeze the model's parameters for the
    # duration so backward() does not accumulate stale .grad buffers on the
    # audited model across iterations/classes.  Restored in ``finally``.
    frozen = [p for p in self.target_model.parameters() if p.requires_grad]
    for p in frozen:
        p.requires_grad_(False)

    try:
        # Start from random noise, requires_grad so we can optimise it
        x = torch.randn(1, *self.input_shape, device=self.device, requires_grad=True)
        optimizer = torch.optim.Adam([x], lr=self.lr)

        best_confidence = 0.0
        best_x = x.detach().clone()

        from tqdm import tqdm
        pbar = tqdm(range(num_iterations), desc=f"  class {target_class}", leave=False,
                    unit="step", ncols=70)
        for _ in pbar:
            optimizer.zero_grad()

            logits = self.target_model(x)
            log_probs = F.log_softmax(logits, dim=1)

            # Score the CURRENT x and snapshot it now: optimizer.step() below
            # updates x in place, so cloning afterwards would pair this
            # confidence with a different image.
            current_confidence = log_probs[0, target_class].exp().item()
            if current_confidence > best_confidence:
                best_confidence = current_confidence
                best_x = x.detach().clone()

            # Classification loss: minimise -log P(target_class).  log_softmax
            # is numerically stable and keeps a useful gradient even when the
            # probability is tiny (unlike log(softmax + eps)).
            cls_loss = -log_probs[0, target_class]

            # Regularisation
            reg_loss = torch.tensor(0.0, device=self.device)
            if self.lambda_tv > 0:
                reg_loss = reg_loss + self.lambda_tv * self._total_variation(x)
            if self.lambda_l2 > 0:
                reg_loss = reg_loss + self.lambda_l2 * torch.norm(x)

            (cls_loss + reg_loss).backward()
            optimizer.step()
            pbar.set_postfix(conf=f"{current_confidence:.3f}")

        # The iterate produced by the last optimiser step has not been scored yet.
        with torch.no_grad():
            final_confidence = F.softmax(self.target_model(x), dim=1)[0, target_class].item()
        if final_confidence > best_confidence:
            best_confidence = final_confidence
            best_x = x.detach().clone()
    finally:
        for p in frozen:
            p.requires_grad_(True)

    logger.info(
        "Class %d: confidence=%.4f after %d iterations",
        target_class, best_confidence, num_iterations,
    )
    return best_x, best_confidence

generate_report(output_dir: str | Path) -> Path

Generate a complete model inversion report.

Creates the following files in output_dir:

  • metrics.json — overall evaluation metrics
  • reconstructions.png — grid of reconstructed images
  • reconstruction_confidence.png — per-class confidence bar chart
  • similarity_distributions.png — member vs non-member similarity
  • roc_curve.png — ROC curve
  • summary.txt — human-readable text summary

Parameters:

Name Type Description Default
output_dir str | Path

Directory where all report files are saved.

required

Returns:

Type Description
Path

The output directory.

Source code in src/auditml/attacks/model_inversion.py
def generate_report(self, output_dir: str | Path) -> Path:
    """Write the full model-inversion report into *output_dir*.

    Files produced:

    - ``metrics.json`` — overall evaluation metrics (plus SSIM scores when
      at least two classes were reconstructed)
    - ``reconstructions.png`` — grid of reconstructed images
    - ``reconstruction_confidence.png`` — per-class confidence bar chart
    - ``similarity_distributions.png`` — member vs non-member similarity
    - ``roc_curve.png`` — ROC curve
    - ``summary.txt`` — human-readable text summary

    Parameters
    ----------
    output_dir:
        Destination directory; created (including parents) if missing.

    Returns
    -------
    Path
        The output directory.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None:
        raise RuntimeError("Call run() before generate_report().")

    # Imported lazily so plotting dependencies are only needed for reports.
    from auditml.attacks.visualization import (
        plot_reconstruction_confidence,
        plot_reconstructions,
        plot_roc_curve,
        plot_score_distributions,
    )

    report_dir = Path(output_dir)
    report_dir.mkdir(parents=True, exist_ok=True)

    report_metrics = self.evaluate()

    class_ids = list(self.reconstructions)
    if len(class_ids) > 1:
        # Each reconstruction is scored against the element-wise MEAN of all
        # reconstructions (not pairwise).  A high score means a reconstruction
        # sits close to the average, i.e. LOW diversity between classes.
        # NOTE(review): _batch_ssim receives flattened 1-D vectors here —
        # confirm that is the input shape it expects.
        vectors = [self.reconstructions[cid].flatten() for cid in class_ids]
        centroid = np.stack(vectors).mean(axis=0)
        scores = _batch_ssim(vectors, [centroid] * len(vectors))
        report_metrics["mean_ssim_vs_mean_recon"] = float(np.mean(scores))
        report_metrics["reconstruction_ssim_scores"] = dict(
            (int(cid), float(score)) for cid, score in zip(class_ids, scores)
        )

    with open(report_dir / "metrics.json", "w") as fh:
        json.dump(report_metrics, fh, indent=2)

    plot_reconstructions(
        reconstructions=self.reconstructions,
        confidences=self.reconstruction_confidences,
        save_path=report_dir / "reconstructions.png",
    )

    plot_reconstruction_confidence(
        confidences=self.reconstruction_confidences,
        save_path=report_dir / "reconstruction_confidence.png",
    )

    plot_score_distributions(
        member_scores=self.member_scores,
        nonmember_scores=self.nonmember_scores,
        metric_name="cosine similarity",
        save_path=report_dir / "similarity_distributions.png",
        title="Similarity Distribution — Model Inversion",
    )

    plot_roc_curve(
        ground_truth=self.result.ground_truth,
        confidence_scores=self.result.confidence_scores,
        title="ROC Curve — Model Inversion",
        save_path=report_dir / "roc_curve.png",
    )

    self._write_summary(report_dir / "summary.txt", report_metrics)

    return report_dir

Attribute Inference

auditml.attacks.attribute_inference.AttributeInference

Bases: BaseAttack

Attribute inference attack via model output analysis.

For each sample the attacker observes the target model's softmax probability vector and tries to predict a sensitive attribute that the model was not designed to reveal. The attack trains a small MLP on the member (training) data, then evaluates whether members' attributes are more predictable than non-members'.

Parameters:

Name Type Description Default
target_model Module

The trained model being audited.

required
config

Full AuditML configuration.

None
device device | str

Torch device.

'cpu'
num_groups int | None

Override for the number of sensitive-attribute groups. If None, determined automatically from the dataset.

None
class_to_group dict[int, int] | None

Explicit mapping {class_label: group_id}. If None, a default mapping is used based on the dataset.

None
Source code in src/auditml/attacks/attribute_inference.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
class AttributeInference(BaseAttack):
    """Attribute inference attack via model output analysis.

    For each sample the attacker observes the target model's softmax
    probability vector and tries to predict a sensitive attribute that
    the model was *not* designed to reveal.  The attack trains a small
    MLP on the member (training) data, then evaluates whether members'
    attributes are more predictable than non-members'.

    Parameters
    ----------
    target_model:
        The trained model being audited.
    config:
        Optional full AuditML configuration. When ``None`` the explicit
        keyword arguments (or built-in defaults) are used instead.
    device:
        Torch device.
    num_groups:
        Override for the number of sensitive-attribute groups. Classes are
        assigned round-robin (``class % num_groups``). Ignored when
        *class_to_group* is given.
    class_to_group:
        Explicit mapping ``{class_label: group_id}``. Group IDs are expected
        to be non-negative integers; the number of groups is
        ``max(group_id) + 1``. If ``None``, a default mapping is used based
        on the dataset.
    num_classes:
        Number of target-model classes. Falls back to
        ``config.model.num_classes`` and then to 10.
    sensitive_attribute:
        Name of the attribute being inferred (recorded in metadata and the
        summary). Falls back to the config value and then to ``"label"``.
    attack_model_type:
        Attack model type label. Falls back to the config value and then
        to ``"mlp"``.
    """

    attack_name = "attribute_inference"

    def __init__(
        self,
        target_model: nn.Module,
        config=None,
        device: torch.device | str = "cpu",
        num_groups: int | None = None,
        class_to_group: dict[int, int] | None = None,
        *,
        num_classes: int | None = None,
        sensitive_attribute: str | None = None,
        attack_model_type: str | None = None,
    ) -> None:
        super().__init__(target_model, config, device)

        # Explicit params take priority; fall back to config; then hardcoded defaults.
        cfg_p = config.attack_params.attribute_inference if config is not None else None
        self.sensitive_attribute = (
            sensitive_attribute or (cfg_p.sensitive_attribute if cfg_p else "label")
        )
        self.attack_model_type = attack_model_type or (cfg_p.attack_model if cfg_p else "mlp")
        self.num_classes = num_classes or (config.model.num_classes if config is not None else 10)
        dataset_name = config.data.dataset.value if config is not None else None

        # Build the class → group mapping.  Mappings are copied so later edits
        # to self.class_to_group never mutate the caller's dict or the shared
        # module-level defaults.
        if class_to_group is not None:
            self.class_to_group = dict(class_to_group)
            self.num_groups = self._count_groups(self.class_to_group)
        elif num_groups is not None:
            self.num_groups = num_groups
            self.class_to_group = {
                c: c % num_groups for c in range(self.num_classes)
            }
        elif dataset_name in _DEFAULT_GROUPS and _DEFAULT_GROUPS[dataset_name] is not None:
            self.class_to_group = dict(_DEFAULT_GROUPS[dataset_name])
            self.num_groups = self._count_groups(self.class_to_group)
        else:
            # Fallback: auto-generate with ≈5 classes per group
            self.num_groups = max(2, self.num_classes // 5)
            self.class_to_group = {
                c: c % self.num_groups for c in range(self.num_classes)
            }

        # Populated during run()
        self.attack_model: AttributeAttackMLP | None = None
        self.member_labels: np.ndarray | None = None
        self.nonmember_labels: np.ndarray | None = None
        self.member_confidence: np.ndarray | None = None
        self.nonmember_confidence: np.ndarray | None = None

    @staticmethod
    def _count_groups(mapping: dict[int, int]) -> int:
        """Return the number of group slots a mapping needs (``max id + 1``; 0 if empty)."""
        # len(set(values)) undercounts when IDs are non-contiguous (e.g. {0, 2}),
        # which makes CrossEntropyLoss reject the larger IDs as out of range.
        # For contiguous IDs 0..k-1 both formulas give k.
        return max(mapping.values(), default=-1) + 1

    # ------------------------------------------------------------------
    # Main attack logic
    # ------------------------------------------------------------------

    def run(
        self,
        member_loader: DataLoader,
        nonmember_loader: DataLoader,
    ) -> AttackResult:
        """Execute the attribute inference attack.

        Steps:

        1. Extract the target model's softmax outputs for all samples.
        2. Map class labels → sensitive attribute (group labels).
        3. Train an attack MLP on *member* outputs → group.
        4. Score every sample by the confidence of the correct group
           prediction.  Members should score higher.
        5. Threshold the scores at their median to get binary
           membership predictions.

        Parameters
        ----------
        member_loader:
            DataLoader over training (member) samples.
        nonmember_loader:
            DataLoader over non-member samples.

        Returns
        -------
        AttackResult
        """
        # Step 1: Extract model outputs
        member_probs, _, member_true_labels = self.get_model_outputs(member_loader)
        nonmember_probs, _, nonmember_true_labels = self.get_model_outputs(nonmember_loader)

        self.member_labels = member_true_labels
        self.nonmember_labels = nonmember_true_labels

        # Step 2: Sensitive attribute labels (group assignments)
        member_groups = self._labels_to_groups(member_true_labels)
        nonmember_groups = self._labels_to_groups(nonmember_true_labels)

        logger.info(
            "Training attribute attack model: %d groups, %d member samples",
            self.num_groups, len(member_probs),
        )

        # Step 3: Train attack model on member data
        # NOTE(review): the member scores below come from the same samples the
        # attack MLP was fit on, so part of the member/non-member gap may be
        # the attack MLP's own overfitting rather than target-model leakage.
        self.attack_model = self._train_attack_model(member_probs, member_groups)

        # Step 4: Predict attribute confidence for both sets
        member_attr_conf = self._predict_attribute_confidence(
            member_probs, member_groups,
        )
        nonmember_attr_conf = self._predict_attribute_confidence(
            nonmember_probs, nonmember_groups,
        )

        self.member_confidence = member_attr_conf
        self.nonmember_confidence = nonmember_attr_conf

        logger.info(
            "Attribute confidence — members: %.4f, non-members: %.4f",
            float(member_attr_conf.mean()), float(nonmember_attr_conf.mean()),
        )

        # Step 5: Build membership inference signal
        ground_truth = np.concatenate([
            np.ones(len(member_attr_conf)),
            np.zeros(len(nonmember_attr_conf)),
        ])
        all_scores = np.concatenate([member_attr_conf, nonmember_attr_conf])

        # Threshold at median for binary predictions
        threshold = float(np.median(all_scores))
        predictions = (all_scores >= threshold).astype(np.int32)

        self.result = AttackResult(
            predictions=predictions,
            ground_truth=ground_truth,
            confidence_scores=all_scores,
            attack_name=self.attack_name,
            metadata={
                "num_groups": self.num_groups,
                "sensitive_attribute": self.sensitive_attribute,
                "mean_member_attr_confidence": float(member_attr_conf.mean()),
                "mean_nonmember_attr_confidence": float(nonmember_attr_conf.mean()),
                # Store probs for per-group evaluation (prefixed with _ to
                # exclude from summary text)
                "_member_probs": member_probs,
                "_nonmember_probs": nonmember_probs,
            },
        )
        return self.result

    # ------------------------------------------------------------------
    # Label → group mapping
    # ------------------------------------------------------------------

    def _labels_to_groups(self, labels: np.ndarray) -> np.ndarray:
        """Convert class labels to sensitive-attribute group IDs.

        Classes missing from ``self.class_to_group`` fall back to group 0.

        Parameters
        ----------
        labels:
            Integer class labels, shape ``(N,)``.

        Returns
        -------
        np.ndarray
            Integer (int64) group IDs, shape ``(N,)``.
        """
        # Explicit dtype: np.array([]) would otherwise be float64, which
        # cannot be used as an index array downstream.
        return np.array(
            [self.class_to_group.get(int(c), 0) for c in labels],
            dtype=np.int64,
        )

    # ------------------------------------------------------------------
    # Attack model training
    # ------------------------------------------------------------------

    def _train_attack_model(
        self,
        probs: np.ndarray,
        groups: np.ndarray,
        epochs: int = 50,
        lr: float = 0.001,
        batch_size: int = 64,
    ) -> AttributeAttackMLP:
        """Train an MLP to predict the sensitive attribute.

        Parameters
        ----------
        probs:
            Softmax outputs from the target model, shape ``(N, C)``.
        groups:
            Sensitive attribute labels, shape ``(N,)``.
        epochs:
            Number of training epochs.
        lr:
            Learning rate for Adam.
        batch_size:
            Mini-batch size (shuffled each epoch).

        Returns
        -------
        AttributeAttackMLP
            The trained attack model (in eval mode).
        """
        model = AttributeAttackMLP(
            input_dim=probs.shape[1],
            num_groups=self.num_groups,
        )
        model.to(self.device)

        x = torch.tensor(probs, dtype=torch.float32)
        y = torch.tensor(groups, dtype=torch.long)
        loader = DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                optimizer.zero_grad()
                loss = criterion(model(batch_x), batch_y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if (epoch + 1) % 25 == 0:
                logger.debug(
                    "Attack model epoch %d/%d — loss: %.4f",
                    # max(..., 1) avoids ZeroDivisionError on an empty loader.
                    epoch + 1, epochs, total_loss / max(len(loader), 1),
                )

        model.eval()
        return model

    # ------------------------------------------------------------------
    # Attribute prediction
    # ------------------------------------------------------------------

    @torch.no_grad()
    def _predict_attribute_confidence(
        self,
        probs: np.ndarray,
        true_groups: np.ndarray,
    ) -> np.ndarray:
        """Measure how confidently the attack model predicts each sample's attribute.

        For each sample, returns the softmax probability that the attack
        model assigns to the *correct* group.  Higher confidence means
        the model's output is more informative about the sensitive
        attribute — a sign that the sample was in the training data.

        Parameters
        ----------
        probs:
            Softmax outputs from the target model, shape ``(N, C)``.
        true_groups:
            Ground-truth group labels, shape ``(N,)``.

        Returns
        -------
        np.ndarray
            Shape ``(N,)`` — confidence in the correct group for each
            sample.
        """
        self.attack_model.eval()
        x = torch.tensor(probs, dtype=torch.float32).to(self.device)
        logits = self.attack_model(x)
        pred_probs = F.softmax(logits, dim=1).cpu().numpy()

        # Confidence = probability assigned to the correct group
        group_idx = np.asarray(true_groups, dtype=np.int64)
        return pred_probs[np.arange(len(group_idx)), group_idx]

    # ------------------------------------------------------------------
    # Attribute prediction accuracy
    # ------------------------------------------------------------------

    @torch.no_grad()
    def predict_attributes(self, probs: np.ndarray) -> np.ndarray:
        """Predict the sensitive attribute for each sample.

        Parameters
        ----------
        probs:
            Softmax outputs from the target model, shape ``(N, C)``.

        Returns
        -------
        np.ndarray
            Predicted group IDs, shape ``(N,)``.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.attack_model is None:
            raise RuntimeError("Call run() before predict_attributes().")

        self.attack_model.eval()
        x = torch.tensor(probs, dtype=torch.float32).to(self.device)
        logits = self.attack_model(x)
        return logits.argmax(dim=1).cpu().numpy()

    # ------------------------------------------------------------------
    # Per-class evaluation (membership inference per original class)
    # ------------------------------------------------------------------

    def evaluate_per_class(self) -> dict[int, dict[str, float]]:
        """Compute evaluation metrics separately for each original class.

        Groups all samples by their original class label and computes the
        full metric suite for each class.  This reveals which classes are
        most vulnerable to the attribute inference attack.  Classes whose
        samples are all members (or all non-members) get accuracy only,
        with every other metric reported as 0.0.

        Returns
        -------
        dict[int, dict[str, float]]
            Mapping from class label to metric dictionary.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before evaluate_per_class().")

        all_labels = np.concatenate([self.member_labels, self.nonmember_labels])

        per_class: dict[int, dict[str, float]] = {}
        for cls in np.unique(all_labels):
            mask = all_labels == cls
            preds_cls = self.result.predictions[mask]
            gt_cls = self.result.ground_truth[mask]
            scores_cls = self.result.confidence_scores[mask]

            # Ranking metrics are undefined without both members and non-members.
            if len(gt_cls) < 2 or len(np.unique(gt_cls)) < 2:
                per_class[int(cls)] = {
                    "accuracy": float(np.mean(preds_cls == gt_cls)) if len(gt_cls) > 0 else 0.0,
                    "precision": 0.0,
                    "recall": 0.0,
                    "f1": 0.0,
                    "auc_roc": 0.0,
                    "auc_pr": 0.0,
                    "tpr_at_1fpr": 0.0,
                    "tpr_at_01fpr": 0.0,
                    "n_samples": int(mask.sum()),
                }
                continue

            metrics = self._compute_metrics(preds_cls, gt_cls, scores_cls)
            metrics["n_samples"] = int(mask.sum())
            per_class[int(cls)] = metrics

        return per_class

    # ------------------------------------------------------------------
    # Per-group evaluation (attribute prediction accuracy per group)
    # ------------------------------------------------------------------

    def evaluate_per_group(self) -> dict[str, dict[int, float]]:
        """Compute attribute prediction accuracy for each group.

        Returns separate accuracy dictionaries for members and
        non-members.  A gap between the two signals privacy leakage.
        Groups with no samples in a split are omitted from that split.

        Returns
        -------
        dict with keys ``"member"`` and ``"nonmember"``, each mapping
        group ID to prediction accuracy on that group.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None or self.attack_model is None:
            raise RuntimeError("Call run() before evaluate_per_group().")

        def accuracy_by_group(preds: np.ndarray, groups: np.ndarray) -> dict[int, float]:
            acc: dict[int, float] = {}
            for g in range(self.num_groups):
                mask = groups == g
                if mask.sum() > 0:
                    acc[g] = float((preds[mask] == groups[mask]).mean())
            return acc

        member_groups = self._labels_to_groups(self.member_labels)
        nonmember_groups = self._labels_to_groups(self.nonmember_labels)

        member_preds = self.predict_attributes(self._get_stored_probs("member"))
        nonmember_preds = self.predict_attributes(self._get_stored_probs("nonmember"))

        return {
            "member": accuracy_by_group(member_preds, member_groups),
            "nonmember": accuracy_by_group(nonmember_preds, nonmember_groups),
        }

    def _get_stored_probs(self, split: str) -> np.ndarray:
        """Return the target-model softmax outputs cached by ``run()``.

        ``run()`` stores the full probability matrices in
        ``self.result.metadata`` under ``"_member_probs"`` and
        ``"_nonmember_probs"``; the per-sample scalar confidences alone are
        not enough to call ``predict_attributes``.

        Parameters
        ----------
        split:
            ``"member"`` or ``"nonmember"``.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called or the probabilities are missing.
        """
        key = f"_{split}_probs"
        if self.result is None or key not in self.result.metadata:
            raise RuntimeError(
                "Probabilities not stored. Ensure run() was called."
            )
        return self.result.metadata[key]

    # ------------------------------------------------------------------
    # Report generation
    # ------------------------------------------------------------------

    def generate_report(self, output_dir: str | Path) -> Path:
        """Generate a complete evaluation report with metrics and plots.

        Creates the following files in *output_dir*:

        - ``metrics.json`` — overall evaluation metrics
        - ``per_class_metrics.json`` — per-class breakdown
        - ``per_group_accuracy.json`` — per-group attribute accuracy
        - ``roc_curve.png`` — ROC curve plot
        - ``confidence_distributions.png`` — member vs non-member histogram
        - ``per_class_accuracy.png`` — bar chart of per-class accuracy
        - ``attribute_accuracy.png`` — per-group member vs non-member accuracy
        - ``summary.txt`` — human-readable text summary

        Parameters
        ----------
        output_dir:
            Directory where all report files are saved (created if missing).

        Returns
        -------
        Path
            The output directory.

        Raises
        ------
        RuntimeError
            If ``run()`` has not been called yet.
        """
        if self.result is None:
            raise RuntimeError("Call run() before generate_report().")

        from auditml.attacks.visualization import (
            plot_attribute_accuracy,
            plot_per_class_metrics,
            plot_roc_curve,
            plot_score_distributions,
        )

        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # 1. Overall metrics
        metrics = self.evaluate()
        with open(out / "metrics.json", "w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)

        # 2. Per-class metrics (JSON object keys must be strings)
        per_class = self.evaluate_per_class()
        per_class_str = {str(k): v for k, v in per_class.items()}
        with open(out / "per_class_metrics.json", "w", encoding="utf-8") as f:
            json.dump(per_class_str, f, indent=2)

        # 3. Per-group attribute accuracy
        per_group = self.evaluate_per_group()
        with open(out / "per_group_accuracy.json", "w", encoding="utf-8") as f:
            serialised = {k: {str(g): v for g, v in d.items()} for k, d in per_group.items()}
            json.dump(serialised, f, indent=2)

        # 4. ROC curve
        plot_roc_curve(
            ground_truth=self.result.ground_truth,
            confidence_scores=self.result.confidence_scores,
            title="ROC Curve — Attribute Inference Attack",
            save_path=out / "roc_curve.png",
        )

        # 5. Confidence distribution histogram
        plot_score_distributions(
            member_scores=self.member_confidence,
            nonmember_scores=self.nonmember_confidence,
            metric_name="attribute confidence",
            save_path=out / "confidence_distributions.png",
            title="Attribute Confidence Distribution — Members vs Non-Members",
        )

        # 6. Per-class accuracy bar chart
        plot_per_class_metrics(
            per_class_metrics=per_class,
            save_path=out / "per_class_accuracy.png",
        )

        # 7. Per-group attribute accuracy comparison
        plot_attribute_accuracy(
            member_accuracy=per_group["member"],
            nonmember_accuracy=per_group["nonmember"],
            save_path=out / "attribute_accuracy.png",
        )

        # 8. Summary text
        self._write_summary(out / "summary.txt", metrics, per_class, per_group)

        return out

    def _write_summary(
        self,
        path: Path,
        metrics: dict[str, float],
        per_class: dict[int, dict[str, float]],
        per_group: dict[str, dict[int, float]],
    ) -> None:
        """Write a human-readable text summary of the attack results (UTF-8)."""
        lines = [
            "=" * 60,
            "AuditML — Attribute Inference Attack Report",
            "=" * 60,
            "",
            f"Sensitive attribute: {self.sensitive_attribute}",
            f"Number of groups:    {self.num_groups}",
            f"Total samples:       {len(self.result.predictions)}",
            f"  Members:           {int(self.result.ground_truth.sum())}",
            f"  Non-members:       {int((1 - self.result.ground_truth).sum())}",
            "",
            "--- Overall Metrics ---",
        ]
        for key, val in metrics.items():
            lines.append(f"  {key:<20s}: {val:.4f}")

        lines.append("")
        lines.append("--- Per-Group Attribute Accuracy ---")
        member_acc_map = per_group.get("member", {})
        nonmember_acc_map = per_group.get("nonmember", {})
        for g in sorted(set(member_acc_map) | set(nonmember_acc_map)):
            mem_acc = member_acc_map.get(g, 0.0)
            nonmem_acc = nonmember_acc_map.get(g, 0.0)
            gap = mem_acc - nonmem_acc
            lines.append(
                f"  Group {g:>3d}:  member={mem_acc:.3f}  "
                f"non-member={nonmem_acc:.3f}  gap={gap:+.3f}"
            )

        lines.append("")
        lines.append("--- Per-Class Membership Accuracy ---")
        for cls in sorted(per_class.keys()):
            m = per_class[cls]
            lines.append(
                f"  Class {cls:>3d}:  acc={m['accuracy']:.3f}  "
                f"auc={m['auc_roc']:.3f}  n={m['n_samples']}"
            )

        lines.append("")
        lines.append("--- Metadata ---")
        for key, val in self.result.metadata.items():
            # Underscore-prefixed entries hold bulk arrays (e.g. cached probs).
            if not key.startswith("_"):
                lines.append(f"  {key}: {val}")

        lines.append("")
        # Explicit UTF-8: the report contains non-ASCII characters ("—"),
        # which the platform default encoding may not be able to represent.
        path.write_text("\n".join(lines), encoding="utf-8")

run(member_loader: DataLoader, nonmember_loader: DataLoader) -> AttackResult

Execute the attribute inference attack.

Steps:

  1. Extract the target model's softmax outputs for all samples.
  2. Map class labels → sensitive attribute (group labels).
  3. Train an attack MLP on member outputs → group.
  4. Score every sample by the confidence of the correct group prediction. Members should score higher.

Parameters:

Name Type Description Default
member_loader DataLoader

DataLoader over training (member) samples.

required
nonmember_loader DataLoader

DataLoader over non-member samples.

required

Returns:

Type Description
AttackResult
Source code in src/auditml/attacks/attribute_inference.py
def run(
    self,
    member_loader: DataLoader,
    nonmember_loader: DataLoader,
) -> AttackResult:
    """Execute the attribute inference attack.

    Pipeline:

    1. Query the target model for softmax outputs on every member and
       non-member sample.
    2. Translate the true class labels into sensitive-attribute group
       IDs (``self._labels_to_groups``).
    3. Fit the attack MLP on the *member* outputs → group mapping.
    4. Score every sample by the attack model's confidence in its
       correct group.  Members are expected to score higher.
    5. Turn those confidences into a membership signal: binary
       predictions come from thresholding at the median of all scores.

    Parameters
    ----------
    member_loader:
        DataLoader over training (member) samples.
    nonmember_loader:
        DataLoader over non-member samples.

    Returns
    -------
    AttackResult
        Also stored on ``self.result``.  Scores are ordered members
        first, then non-members.  The raw softmax outputs are kept in the
        metadata under underscore-prefixed keys, which the text summary
        skips.
    """
    # -- 1. target-model outputs (middle return value is not needed) ----
    mem_probs, _, mem_labels = self.get_model_outputs(member_loader)
    non_probs, _, non_labels = self.get_model_outputs(nonmember_loader)

    # Kept for the per-class / per-group breakdowns.
    self.member_labels = mem_labels
    self.nonmember_labels = non_labels

    # -- 2. sensitive attribute (group) for every sample ----------------
    mem_groups = self._labels_to_groups(mem_labels)
    non_groups = self._labels_to_groups(non_labels)

    logger.info(
        "Training attribute attack model: %d groups, %d member samples",
        self.num_groups, len(mem_probs),
    )

    # -- 3. the attack model is trained on member data only -------------
    self.attack_model = self._train_attack_model(mem_probs, mem_groups)

    # -- 4. confidence in the correct group, per population -------------
    mem_conf = self._predict_attribute_confidence(mem_probs, mem_groups)
    non_conf = self._predict_attribute_confidence(non_probs, non_groups)
    self.member_confidence = mem_conf
    self.nonmember_confidence = non_conf

    mem_mean = float(mem_conf.mean())
    non_mean = float(non_conf.mean())
    logger.info(
        "Attribute confidence — members: %.4f, non-members: %.4f",
        mem_mean, non_mean,
    )

    # -- 5. membership signal: 1 = member, 0 = non-member ---------------
    ground_truth = np.concatenate((
        np.ones(len(mem_conf)),
        np.zeros(len(non_conf)),
    ))
    scores = np.concatenate((mem_conf, non_conf))

    # Median cut-off for the hard predictions.
    cutoff = float(np.median(scores))
    predictions = (scores >= cutoff).astype(np.int32)

    metadata = {
        "num_groups": self.num_groups,
        "sensitive_attribute": self.sensitive_attribute,
        "mean_member_attr_confidence": mem_mean,
        "mean_nonmember_attr_confidence": non_mean,
        # Underscore prefix keeps these arrays out of the text summary.
        "_member_probs": mem_probs,
        "_nonmember_probs": non_probs,
    }
    self.result = AttackResult(
        predictions=predictions,
        ground_truth=ground_truth,
        confidence_scores=scores,
        attack_name=self.attack_name,
        metadata=metadata,
    )
    return self.result

predict_attributes(probs: np.ndarray) -> np.ndarray

Predict the sensitive attribute for each sample.

Parameters:

Name Type Description Default
probs ndarray

Softmax outputs from the target model, shape (N, C).

required

Returns:

Type Description
ndarray

Predicted group IDs, shape (N,).

Source code in src/auditml/attacks/attribute_inference.py
@torch.no_grad()
def predict_attributes(self, probs: np.ndarray) -> np.ndarray:
    """Return the attack model's predicted group for every sample.

    Parameters
    ----------
    probs:
        Softmax outputs from the target model, shape ``(N, C)``.

    Returns
    -------
    np.ndarray
        Predicted group IDs (arg-max over the attack model's logits),
        shape ``(N,)``.

    Raises
    ------
    RuntimeError
        If no attack model exists yet, i.e. ``run()`` was not called.
    """
    model = self.attack_model
    if model is None:
        raise RuntimeError("Call run() before predict_attributes().")

    # Switch the attack model to inference mode; gradients are already
    # disabled by the decorator.
    model.eval()
    batch = torch.tensor(probs, dtype=torch.float32)
    batch = batch.to(self.device)
    predicted = model(batch).argmax(dim=1)
    return predicted.cpu().numpy()

evaluate_per_class() -> dict[int, dict[str, float]]

Compute evaluation metrics separately for each original class.

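Groups all samples by their original class label and computes the full metric suite for each class. This reveals which classes are most vulnerable to the attribute inference attack. Classes that contain only members or only non-members get accuracy only; their remaining metrics are reported as 0.0.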

Returns:

Type Description
dict[int, dict[str, float]]

Mapping from class label to metric dictionary.

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/attribute_inference.py
def evaluate_per_class(self) -> dict[int, dict[str, float]]:
    """Break the membership metrics down by original class label.

    Every scored sample is bucketed by its true class label. Members
    come first, then non-members, matching the order used to build
    ``self.result``. The metric suite is then recomputed inside each
    bucket, which shows which classes are most exposed to the attribute
    inference attack.

    A bucket that holds only members or only non-members cannot support
    the ranking metrics.  For such classes only ``accuracy`` is computed
    and every other metric is reported as ``0.0``.

    Returns
    -------
    dict[int, dict[str, float]]
        Class label → metric dictionary; each dictionary also carries
        ``n_samples``.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None:
        raise RuntimeError("Call run() before evaluate_per_class().")

    # Metrics reported as 0.0 when a class bucket contains one label only.
    undefined = (
        "precision",
        "recall",
        "f1",
        "auc_roc",
        "auc_pr",
        "tpr_at_1fpr",
        "tpr_at_01fpr",
    )

    labels = np.concatenate([self.member_labels, self.nonmember_labels])
    preds = self.result.predictions
    truth = self.result.ground_truth
    scores = self.result.confidence_scores

    breakdown: dict[int, dict[str, float]] = {}
    for cls in np.unique(labels):
        sel = labels == cls
        p, t, s = preds[sel], truth[sel], scores[sel]
        count = int(sel.sum())

        if np.unique(t).size < 2:
            entry = {"accuracy": float(np.mean(p == t)) if t.size > 0 else 0.0}
            entry.update(dict.fromkeys(undefined, 0.0))
        else:
            entry = self._compute_metrics(p, t, s)
        entry["n_samples"] = count
        breakdown[int(cls)] = entry

    return breakdown

evaluate_per_group() -> dict[str, dict[int, float]]

Compute attribute prediction accuracy for each group.

Returns separate accuracy dictionaries for members and non-members. A gap between the two signals privacy leakage.

Returns:

Type Description
dict[str, dict[int, float]]

Mapping with keys "member" and "nonmember", each mapping group ID to prediction accuracy on that group.

Raises:

Type Description
RuntimeError

If run() has not been called yet.

Source code in src/auditml/attacks/attribute_inference.py
def evaluate_per_group(self) -> dict[str, dict[int, float]]:
    """Measure how accurately the attack model recovers each group.

    Accuracy is computed separately for members and non-members.  A gap
    between the two populations signals privacy leakage.

    Returns
    -------
    dict[str, dict[int, float]]
        ``{"member": {...}, "nonmember": {...}}``.  Each inner mapping
        goes from group ID (in ``range(self.num_groups)``) to prediction
        accuracy.  Groups absent from a population are omitted.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None or self.attack_model is None:
        raise RuntimeError("Call run() before evaluate_per_group().")

    mem_truth = self._labels_to_groups(self.member_labels)
    non_truth = self._labels_to_groups(self.nonmember_labels)
    mem_guess = self.predict_attributes(self._get_stored_probs("member"))
    non_guess = self.predict_attributes(self._get_stored_probs("nonmember"))

    def _accuracy_by_group(truth, guess):
        # Fraction of correct predictions within each non-empty group.
        table: dict[int, float] = {}
        for grp in range(self.num_groups):
            hit = truth == grp
            if hit.any():
                table[grp] = float((guess[hit] == truth[hit]).mean())
        return table

    return {
        "member": _accuracy_by_group(mem_truth, mem_guess),
        "nonmember": _accuracy_by_group(non_truth, non_guess),
    }

generate_report(output_dir: str | Path) -> Path

Generate a complete evaluation report with metrics and plots.

Creates the following files in output_dir:

  • metrics.json — overall evaluation metrics
  • per_class_metrics.json — per-class breakdown
  • per_group_accuracy.json — per-group attribute accuracy
  • roc_curve.png — ROC curve plot
  • confidence_distributions.png — member vs non-member histogram
  • per_class_accuracy.png — bar chart of per-class accuracy
  • attribute_accuracy.png — per-group member vs non-member accuracy
  • summary.txt — human-readable text summary

Parameters:

Name Type Description Default
output_dir str | Path

Directory where all report files are saved.

required

Returns:

Type Description
Path

The output directory.

Source code in src/auditml/attacks/attribute_inference.py
def generate_report(self, output_dir: str | Path) -> Path:
    """Write a full evaluation report: JSON metrics, plots and a summary.

    Files created in *output_dir* (created, with parents, if missing):

    - ``metrics.json`` — overall evaluation metrics
    - ``per_class_metrics.json`` — per-class breakdown
    - ``per_group_accuracy.json`` — per-group attribute accuracy
    - ``roc_curve.png`` — ROC curve plot
    - ``confidence_distributions.png`` — member vs non-member histogram
    - ``per_class_accuracy.png`` — bar chart of per-class accuracy
    - ``attribute_accuracy.png`` — per-group member vs non-member accuracy
    - ``summary.txt`` — human-readable text summary

    Parameters
    ----------
    output_dir:
        Directory that receives every report file.

    Returns
    -------
    Path
        ``output_dir`` as a :class:`~pathlib.Path`.

    Raises
    ------
    RuntimeError
        If ``run()`` has not been called yet.
    """
    if self.result is None:
        raise RuntimeError("Call run() before generate_report().")

    # Imported here rather than at module level (presumably to keep the
    # plotting stack off the attack's import path).
    from auditml.attacks.visualization import (
        plot_attribute_accuracy,
        plot_per_class_metrics,
        plot_roc_curve,
        plot_score_distributions,
    )

    report_dir = Path(output_dir)
    report_dir.mkdir(parents=True, exist_ok=True)

    def _save_json(filename: str, payload) -> None:
        # Pretty-printed JSON, one file per artefact.
        with open(report_dir / filename, "w") as fh:
            json.dump(payload, fh, indent=2)

    # ---- JSON artefacts (integer keys converted to str) --------------
    overall = self.evaluate()
    _save_json("metrics.json", overall)

    per_class = self.evaluate_per_class()
    _save_json(
        "per_class_metrics.json",
        {str(cls): values for cls, values in per_class.items()},
    )

    per_group = self.evaluate_per_group()
    _save_json(
        "per_group_accuracy.json",
        {
            population: {str(grp): acc for grp, acc in accs.items()}
            for population, accs in per_group.items()
        },
    )

    # ---- figures -------------------------------------------------------
    plot_roc_curve(
        ground_truth=self.result.ground_truth,
        confidence_scores=self.result.confidence_scores,
        title="ROC Curve — Attribute Inference Attack",
        save_path=report_dir / "roc_curve.png",
    )
    plot_score_distributions(
        member_scores=self.member_confidence,
        nonmember_scores=self.nonmember_confidence,
        metric_name="attribute confidence",
        save_path=report_dir / "confidence_distributions.png",
        title="Attribute Confidence Distribution — Members vs Non-Members",
    )
    plot_per_class_metrics(
        per_class_metrics=per_class,
        save_path=report_dir / "per_class_accuracy.png",
    )
    plot_attribute_accuracy(
        member_accuracy=per_group["member"],
        nonmember_accuracy=per_group["nonmember"],
        save_path=report_dir / "attribute_accuracy.png",
    )

    # ---- plain-text summary -------------------------------------------
    self._write_summary(report_dir / "summary.txt", overall, per_class, per_group)

    return report_dir

Visualisation helpers

auditml.attacks.visualization

Visualization utilities for AuditML attack results.

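Provides reusable plotting functions that any attack can call. All functions optionally save to disk and return a matplotlib.figure.Figure so callers can further customise it. Each function closes its figure with plt.close before returning, so the returned Figure is no longer managed by pyplot. Re-save it with fig.savefig or display it explicitly rather than relying on plt.show().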

The module uses the Agg backend by default so that plots can be generated on headless servers (e.g. Colab, CI) without requiring a display.

plot_roc_curve(ground_truth: np.ndarray, confidence_scores: np.ndarray, save_path: str | Path | None = None, title: str = 'ROC Curve — Membership Inference Attack') -> plt.Figure

Plot the Receiver Operating Characteristic curve.

The ROC curve shows the trade-off between True Positive Rate (TPR) and False Positive Rate (FPR) at every possible threshold. The Area Under the Curve (AUC) summarises overall attack effectiveness: 0.5 = random guessing, 1.0 = perfect attack.

Parameters:

Name Type Description Default
ground_truth ndarray

Binary array (1 = member, 0 = non-member).

required
confidence_scores ndarray

Continuous scores where higher = more likely member.

required
save_path str | Path | None

If given, the plot is saved to this path.

None
title str

Plot title.

'ROC Curve — Membership Inference Attack'

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_roc_curve(
    ground_truth: np.ndarray,
    confidence_scores: np.ndarray,
    save_path: str | Path | None = None,
    title: str = "ROC Curve — Membership Inference Attack",
) -> plt.Figure:
    """Draw the Receiver Operating Characteristic curve of an attack.

    The curve traces True Positive Rate (TPR) against False Positive Rate
    (FPR) over every decision threshold.  Its area (AUC) condenses attack
    strength into one number: 0.5 = random guessing, 1.0 = perfect attack.
    The 1% and 0.1% FPR operating points are marked, with their TPR
    linearly interpolated along the empirical curve.

    Parameters
    ----------
    ground_truth:
        Binary array (1 = member, 0 = non-member).
    confidence_scores:
        Continuous scores where higher = more likely member.
    save_path:
        Optional file path the plot is written to.
    title:
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.
    """
    fig, ax = plt.subplots(figsize=(7, 6))

    fpr, tpr, _ = roc_curve(ground_truth, confidence_scores)
    area = auc(fpr, tpr)

    # Attack curve, then the chance diagonal.
    ax.plot(fpr, tpr, color="#2563eb", lw=2, label=f"AUC = {area:.4f}")
    ax.plot(
        [0, 1], [0, 1],
        color="grey", lw=1, linestyle="--", label="Random (AUC = 0.5)",
    )

    # Low-FPR operating points: (target FPR, marker, legend caption).
    operating_points = (
        (0.01, "o", "TPR @ 1% FPR"),
        (0.001, "s", "TPR @ 0.1% FPR"),
    )
    for fpr_target, shape, caption in operating_points:
        tpr_here = float(np.interp(fpr_target, fpr, tpr))
        ax.plot(
            fpr_target, tpr_here, marker=shape, markersize=8,
            label=f"{caption} = {tpr_here:.3f}",
        )

    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.legend(loc="lower right")
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig

plot_score_distributions(member_scores: np.ndarray, nonmember_scores: np.ndarray, metric_name: str = 'loss', threshold: float | None = None, save_path: str | Path | None = None, title: str | None = None) -> plt.Figure

Plot overlapping histograms of member vs non-member scores.

This is the most intuitive visualisation for threshold MIA: if the two distributions overlap a lot, the attack is weak (can't tell members from non-members). If they're well-separated, the attack is strong.

Parameters:

Name Type Description Default
member_scores ndarray

Raw signal values for training members.

required
nonmember_scores ndarray

Raw signal values for non-members.

required
metric_name str

Name of the metric (for axis label).

'loss'
threshold float | None

If given, a vertical line is drawn at this value.

None
save_path str | Path | None

If given, saves the figure.

None
title str | None

Plot title. Auto-generated if None.

None

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_score_distributions(
    member_scores: np.ndarray,
    nonmember_scores: np.ndarray,
    metric_name: str = "loss",
    threshold: float | None = None,
    save_path: str | Path | None = None,
    title: str | None = None,
) -> plt.Figure:
    """Plot overlapping histograms of member vs non-member scores.

    This is the most intuitive visualisation for threshold MIA. If the
    two distributions overlap a lot, the attack is weak (it can't tell
    members from non-members). If they are well separated, the attack
    is strong.

    Parameters
    ----------
    member_scores:
        Raw signal values for training members.
    nonmember_scores:
        Raw signal values for non-members.
    metric_name:
        Name of the metric (for axis label).
    threshold:
        If given, a vertical line is drawn at this value.
    save_path:
        If given, saves the figure.
    title:
        Plot title. Auto-generated if ``None``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.

    Notes
    -----
    Both histograms share 50 bin edges spanning the finite range of the
    pooled scores. When that range is degenerate (every score equal), it
    is widened by 0.5 on each side, as :func:`numpy.histogram` does for a
    zero-width range. This keeps the density normalisation from dividing
    by zero-width bins. When there are no finite scores at all, the edges
    span ``[0, 1]``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    if title is None:
        title = f"Score Distribution — {metric_name.capitalize()} Metric"

    # Shared bin edges so the two densities are directly comparable.
    all_scores = np.concatenate([member_scores, nonmember_scores])
    finite = all_scores[np.isfinite(all_scores)]
    if finite.size:
        lo, hi = finite.min(), finite.max()
    else:
        # No usable values (empty input or all NaN/inf): fall back to [0, 1]
        # instead of letting .min() raise on an empty array.
        lo, hi = 0.0, 1.0
    if hi <= lo:
        # Identical scores would give 50 equal edges; density=True would then
        # divide by zero bin widths and produce NaN/inf bar heights.
        lo, hi = lo - 0.5, hi + 0.5
    bins = np.linspace(lo, hi, 50)

    ax.hist(member_scores, bins=bins, alpha=0.6, color="#2563eb",
            label=f"Members (n={len(member_scores)})", density=True)
    ax.hist(nonmember_scores, bins=bins, alpha=0.6, color="#dc2626",
            label=f"Non-members (n={len(nonmember_scores)})", density=True)

    if threshold is not None:
        ax.axvline(threshold, color="black", lw=2, linestyle="--",
                   label=f"Threshold = {threshold:.4f}")

    ax.set_xlabel(f"{metric_name.capitalize()} Score")
    ax.set_ylabel("Density")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig

plot_per_class_metrics(per_class_metrics: dict[int, dict[str, float]], metric_key: str = 'accuracy', save_path: str | Path | None = None, title: str | None = None) -> plt.Figure

Bar chart showing a chosen metric for each class.

Useful for identifying which classes are most vulnerable to membership inference (higher accuracy = more vulnerable).

Parameters:

Name Type Description Default
per_class_metrics dict[int, dict[str, float]]

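Output of an attack's evaluate_per_class() method (e.g. ThresholdMIA.evaluate_per_class()): a mapping from class label to a metric dictionary, optionally including "n_samples".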

required
metric_key str

Which metric to plot (default "accuracy").

'accuracy'
save_path str | Path | None

If given, saves the figure.

None
title str | None

Plot title.

None

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_per_class_metrics(
    per_class_metrics: dict[int, dict[str, float]],
    metric_key: str = "accuracy",
    save_path: str | Path | None = None,
    title: str | None = None,
) -> plt.Figure:
    """Draw one bar per class for the chosen metric.

    Handy for spotting the classes most exposed to membership inference
    (higher accuracy = more vulnerable).  Each bar is annotated with the
    class's ``n_samples`` (0 when absent), and a dashed line marks 0.5.

    Parameters
    ----------
    per_class_metrics:
        Output of ``ThresholdMIA.evaluate_per_class()``: class label →
        metric dictionary.
    metric_key:
        Metric to plot (default ``"accuracy"``).
    save_path:
        Optional file path the figure is written to.
    title:
        Plot title; derived from *metric_key* when ``None``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.
    """
    pretty_metric = metric_key.replace("_", " ").title()
    if title is None:
        title = f"Per-Class Attack {pretty_metric}"

    class_ids = sorted(per_class_metrics)
    heights = [per_class_metrics[c][metric_key] for c in class_ids]
    counts = [per_class_metrics[c].get("n_samples", 0) for c in class_ids]
    positions = range(len(class_ids))

    fig, ax = plt.subplots(figsize=(max(8, len(class_ids) * 0.5), 5))
    bars = ax.bar(positions, heights, color="#2563eb", alpha=0.8)

    # Sample count printed just above each bar.
    for bar, count in zip(bars, counts):
        x_mid = bar.get_x() + bar.get_width() / 2
        ax.text(x_mid, bar.get_height() + 0.01, f"n={count}",
                ha="center", va="bottom", fontsize=7)

    ax.axhline(0.5, color="grey", linestyle="--", lw=1, alpha=0.7,
               label="Random baseline (0.5)")

    ax.set_xticks(positions)
    ax.set_xticklabels([str(c) for c in class_ids], fontsize=8)
    ax.set_xlabel("Class")
    ax.set_ylabel(pretty_metric)
    ax.set_title(title)
    ax.set_ylim([0, 1.1])
    ax.legend()
    ax.grid(True, axis="y", alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig

plot_reconstructions(reconstructions: dict[int, np.ndarray], confidences: dict[int, float] | None = None, save_path: str | Path | None = None, title: str = 'Model Inversion — Reconstructed Images') -> plt.Figure

Display reconstructed images in a grid, one per class.

This is the key visual for model inversion: if the images look like recognisable digits/objects, the model has leaked training data.

Parameters:

Name Type Description Default
reconstructions dict[int, ndarray]

Mapping from class label to image array of shape (1, C, H, W) or (C, H, W).

required
confidences dict[int, float] | None

Optional mapping from class label to reconstruction confidence. Displayed below each image.

None
save_path str | Path | None

If given, saves the figure.

None
title str

Plot title.

'Model Inversion — Reconstructed Images'

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_reconstructions(
    reconstructions: dict[int, np.ndarray],
    confidences: dict[int, float] | None = None,
    save_path: str | Path | None = None,
    title: str = "Model Inversion — Reconstructed Images",
) -> plt.Figure:
    """Display reconstructed images in a grid, one per class.

    This is the key visual for model inversion. If the images look like
    recognisable digits/objects, the model has leaked training data.

    Parameters
    ----------
    reconstructions:
        Mapping from class label to image array of shape
        ``(1, C, H, W)``, ``(C, H, W)`` or ``(H, W)``.
        - Single-channel and 2-D images are drawn in grayscale.
        - 3-channel images are min-max rescaled to ``[0, 1]`` and drawn
          as RGB.
        - For any other channel count only the first channel is drawn.
        An empty mapping yields a figure containing only the title.
    confidences:
        Optional mapping from class label to reconstruction confidence.
        Displayed below each image.
    save_path:
        If given, saves the figure.
    title:
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.
    """
    classes = sorted(reconstructions.keys())
    n = len(classes)
    # At most 5 images per row. Clamp to at least one row/column so an
    # empty mapping neither divides by zero nor requests a 0-cell grid.
    cols = max(1, min(n, 5))
    rows = max(1, (n + cols - 1) // cols)

    # squeeze=False always yields a 2-D array of Axes, for any grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3.5 * rows),
                             squeeze=False)

    for idx, cls in enumerate(classes):
        row, col = divmod(idx, cols)
        ax = axes[row, col]

        img = reconstructions[cls]
        if img.ndim == 4:
            img = img[0]  # remove batch dim -> (C, H, W)

        if img.ndim == 2:
            # Already (H, W): plain grayscale image.
            ax.imshow(img, cmap="gray")
        elif img.shape[0] == 3:
            # RGB: (C, H, W) -> (H, W, C), normalise to [0, 1]
            img_hwc = np.transpose(img, (1, 2, 0))
            img_hwc = np.clip(
                (img_hwc - img_hwc.min()) / (img_hwc.max() - img_hwc.min() + 1e-8),
                0, 1,
            )
            ax.imshow(img_hwc)
        else:
            # (1, H, W) grayscale, or an unsupported channel count:
            # show the first channel.
            ax.imshow(img[0], cmap="gray")

        label = f"Class {cls}"
        if confidences and cls in confidences:
            label += f"\nconf={confidences[cls]:.3f}"
        ax.set_title(label, fontsize=9)
        ax.axis("off")

    # Hide unused subplots
    for idx in range(n, rows * cols):
        row, col = divmod(idx, cols)
        axes[row, col].axis("off")

    fig.suptitle(title, fontsize=13, fontweight="bold")
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig

plot_reconstruction_confidence(confidences: dict[int, float], save_path: str | Path | None = None, title: str = 'Reconstruction Confidence per Class') -> plt.Figure

Bar chart of model confidence on each reconstructed image.

Higher confidence means the optimisation was more successful at producing an image the model strongly associates with that class.

Parameters:

Name Type Description Default
confidences dict[int, float]

Mapping from class label to confidence (softmax probability).

required
save_path str | Path | None

If given, saves the figure.

None
title str

Plot title.

'Reconstruction Confidence per Class'

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_reconstruction_confidence(
    confidences: dict[int, float],
    save_path: str | Path | None = None,
    title: str = "Reconstruction Confidence per Class",
) -> plt.Figure:
    """Draw one bar per class showing confidence on its reconstruction.

    A taller bar means the optimisation produced an image the model more
    strongly associates with that class.  Each bar is labelled with its
    value to three decimals.

    Parameters
    ----------
    confidences:
        Class label → confidence (softmax probability).
    save_path:
        Optional file path the figure is written to.
    title:
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.
    """
    ordered = sorted(confidences)
    scores = [confidences[k] for k in ordered]
    ticks = range(len(ordered))

    fig, ax = plt.subplots(figsize=(max(8, len(ordered) * 0.5), 5))
    bars = ax.bar(ticks, scores, color="#7c3aed", alpha=0.8)

    # Value label just above each bar.
    for bar, score in zip(bars, scores):
        x_mid = bar.get_x() + bar.get_width() / 2
        ax.text(x_mid, bar.get_height() + 0.01, f"{score:.3f}",
                ha="center", va="bottom", fontsize=8)

    ax.set_xticks(ticks)
    ax.set_xticklabels([str(k) for k in ordered], fontsize=9)
    ax.set_xlabel("Class")
    ax.set_ylabel("Confidence (Softmax Probability)")
    ax.set_title(title)
    ax.set_ylim([0, 1.1])
    ax.grid(True, axis="y", alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig

plot_attribute_accuracy(member_accuracy: dict[int, float], nonmember_accuracy: dict[int, float], save_path: str | Path | None = None, title: str = 'Attribute Prediction Accuracy — Members vs Non-Members') -> plt.Figure

Side-by-side bar chart comparing attribute accuracy for members and non-members.

A large gap between member and non-member accuracy indicates that the model's outputs reveal more about the sensitive attribute for training data — a privacy leak.

Parameters:

Name Type Description Default
member_accuracy dict[int, float]

Mapping from group ID to attribute prediction accuracy on members.

required
nonmember_accuracy dict[int, float]

Mapping from group ID to attribute prediction accuracy on non-members.

required
save_path str | Path | None

If given, saves the figure.

None
title str

Plot title.

'Attribute Prediction Accuracy — Members vs Non-Members'

Returns:

Type Description
Figure
Source code in src/auditml/attacks/visualization.py
def plot_attribute_accuracy(
    member_accuracy: dict[int, float],
    nonmember_accuracy: dict[int, float],
    save_path: str | Path | None = None,
    title: str = "Attribute Prediction Accuracy — Members vs Non-Members",
) -> plt.Figure:
    """Grouped bar chart of attribute accuracy for members vs non-members.

    When members score clearly higher than non-members, the model's
    outputs reveal more about the sensitive attribute for training data.
    That gap is a privacy leak. Groups present in only one mapping are
    drawn with 0.0 for the other. The dashed baseline is
    ``1 / number_of_groups``.

    Parameters
    ----------
    member_accuracy:
        Group ID → attribute prediction accuracy on members.
    nonmember_accuracy:
        Group ID → attribute prediction accuracy on non-members.
    save_path:
        Optional file path the figure is written to.
    title:
        Plot title.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed via ``plt.close`` before returning.
    """
    groups = sorted(set(member_accuracy) | set(nonmember_accuracy))
    member_heights = [member_accuracy.get(g, 0.0) for g in groups]
    nonmember_heights = [nonmember_accuracy.get(g, 0.0) for g in groups]

    centres = np.arange(len(groups))
    bar_width = 0.35
    shift = bar_width / 2

    fig, ax = plt.subplots(figsize=(max(8, len(groups) * 0.8), 5))
    member_bars = ax.bar(centres - shift, member_heights, bar_width,
                         label="Members", color="#2563eb", alpha=0.8)
    nonmember_bars = ax.bar(centres + shift, nonmember_heights, bar_width,
                            label="Non-members", color="#dc2626", alpha=0.8)

    # Value labels: member bars first, then non-member bars.
    for bar in (*member_bars, *nonmember_bars):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2, height + 0.01,
                f"{height:.2f}", ha="center", va="bottom", fontsize=7)

    chance = 1.0 / max(len(groups), 1)
    ax.axhline(chance, color="grey", linestyle="--", lw=1,
               alpha=0.7, label="Random baseline")

    ax.set_xticks(centres)
    ax.set_xticklabels([f"Group {g}" for g in groups], fontsize=8)
    ax.set_xlabel("Sensitive Attribute Group")
    ax.set_ylabel("Prediction Accuracy")
    ax.set_title(title)
    ax.set_ylim([0, 1.15])
    ax.legend()
    ax.grid(True, axis="y", alpha=0.3)

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return fig