API Reference: Models
Base model
auditml.models.base.BaseModel
Bases: ABC, Module
Base class that all AuditML target/shadow models must extend.
Subclasses must implement forward and get_features.
Source code in src/auditml/models/base.py
forward(x: torch.Tensor) -> torch.Tensor (abstract method)
get_features(x: torch.Tensor) -> torch.Tensor (abstract method)
Return the penultimate-layer feature vector.
Used by membership inference attacks to analyse the model's internal
representations. Shape: (batch, feature_dim).
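As a minimal sketch of this contract, the snippet below defines a stand-in `BaseModel` that mirrors only the documented bases (ABC and nn.Module) and the two abstract methods, plus a hypothetical `TinyMLP` subclass. Neither is the actual auditml code. The point it illustrates is that `forward` can be built on top of `get_features`, so the features an attack inspects are exactly the ones the classifier head consumes.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseModel(ABC, nn.Module):
    """Stand-in for auditml.models.base.BaseModel (assumption: interface only)."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def get_features(self, x: torch.Tensor) -> torch.Tensor: ...


class TinyMLP(BaseModel):
    """Hypothetical subclass: one hidden layer followed by a linear classifier head."""

    def __init__(self, in_dim: int = 784, feature_dim: int = 64, num_classes: int = 10) -> None:
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(in_dim, feature_dim), nn.ReLU())
        self.head = nn.Linear(feature_dim, num_classes)

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        # Penultimate-layer representation, shape (batch, feature_dim).
        return self.body(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reuse get_features so both methods agree on the representation.
        return self.head(self.get_features(x))


model = TinyMLP()
x = torch.randn(4, 1, 28, 28)
print(model.get_features(x).shape)  # torch.Size([4, 64])
print(model(x).shape)               # torch.Size([4, 10])
```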
CNN architectures
auditml.models.cnn
Simple CNN architecture for MNIST / CIFAR-10 / CIFAR-100.
The architecture is intentionally compact: two convolutional blocks followed by two fully-connected layers. This mirrors common baselines in the membership inference attack (MIA) literature and trains quickly, which matters when many shadow models must be trained.
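A minimal PyTorch sketch of a network with this shape follows. The channel widths (32 and 64), 3 × 3 kernels, ReLU activations, and 2 × 2 max-pooling are assumptions chosen for illustration, and `CompactCNN` is a hypothetical name; this is not SimpleCNN's source. The constructor arguments mirror SimpleCNN's documented parameters, and `input_size` is needed because it determines the width of the first fully-connected layer.

```python
import torch
from torch import nn


class CompactCNN(nn.Module):
    """Hypothetical two-conv-block CNN; layer widths are illustrative assumptions."""

    def __init__(self, input_channels=1, num_classes=10, input_size=28, feature_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            # Block 1: padded 3x3 conv keeps the spatial size, max-pool halves it.
            nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Block 2: same pattern with twice the channels.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        flat_dim = 64 * (input_size // 4) ** 2  # two 2x pools shrink each side by 4
        self.fc1 = nn.Linear(flat_dim, feature_dim)      # penultimate layer
        self.fc2 = nn.Linear(feature_dim, num_classes)   # classifier

    def get_features(self, x):
        # Flatten conv output, then the penultimate FC layer: (batch, feature_dim).
        return torch.relu(self.fc1(torch.flatten(self.conv(x), 1)))

    def forward(self, x):
        return self.fc2(self.get_features(x))


mnist_model = CompactCNN()
cifar_model = CompactCNN(input_channels=3, num_classes=100, input_size=32)
print(mnist_model(torch.randn(8, 1, 28, 28)).shape)  # torch.Size([8, 10])
print(cifar_model(torch.randn(8, 3, 32, 32)).shape)  # torch.Size([8, 100])
```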
SimpleCNN
Bases: BaseModel
Configurable 2-block CNN.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_channels | int | Number of input channels (1 for MNIST, 3 for CIFAR). | 1 |
| num_classes | int | Number of output classes. | 10 |
| input_size | int | Spatial dimension of the input (28 for MNIST, 32 for CIFAR). | 28 |
| feature_dim | int | Width of the penultimate fully-connected layer. | 128 |
Source code in src/auditml/models/cnn.py
ResNet
auditml.models.resnet
Small ResNet variant for CIFAR datasets.
A compact ResNet-18-style architecture adapted for 32 × 32 inputs (no aggressive stem downsampling). It achieves higher accuracy than SimpleCNN on CIFAR-10/100, which makes it useful for testing whether a more capable model leaks more information about its training data.
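To see why the stem matters at this resolution, the sketch below compares the standard ImageNet ResNet stem (7 × 7 stride-2 convolution plus stride-2 max-pool) with a CIFAR-style stem on a 32 × 32 input. It assumes SmallResNet follows the common CIFAR convention of a single 3 × 3, stride-1 convolution; its exact stem is not shown in this reference.

```python
import torch
from torch import nn

x = torch.randn(1, 3, 32, 32)

# ImageNet-style stem: 7x7 stride-2 conv, then stride-2 max-pool (4x downsampling).
imagenet_stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

# CIFAR-style stem (assumed): one padded 3x3 stride-1 conv keeps full resolution.
cifar_stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

print(imagenet_stem(x).shape)  # torch.Size([1, 64, 8, 8])
print(cifar_stem(x).shape)     # torch.Size([1, 64, 32, 32])
```

With the ImageNet stem, a 32 × 32 image is reduced to 8 × 8 before the first residual stage, leaving little spatial detail for the later stages.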
SmallResNet
Bases: BaseModel
ResNet-18-style network for 32×32 inputs.
Uses GroupNorm instead of BatchNorm so the model is compatible with Opacus (differential-privacy training) out of the box.
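Opacus computes per-sample gradients, and BatchNorm mixes statistics across the samples in a batch, which is why Opacus rejects it; GroupNorm normalises within each sample instead. The sketch below shows a ResNet basic block with GroupNorm in place of BatchNorm. The group count of 8, the `BasicBlock` name, and the block layout are assumptions for illustration, not SmallResNet's actual code.

```python
import torch
from torch import nn


def norm(channels: int, groups: int = 8) -> nn.GroupNorm:
    # Normalises over channel groups within each sample; keeps no running batch stats.
    return nn.GroupNorm(num_groups=groups, num_channels=channels)


class BasicBlock(nn.Module):
    """Illustrative ResNet basic block with GroupNorm instead of BatchNorm."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.norm1 = norm(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.norm2 = norm(out_ch)
        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            # 1x1 projection so the skip connection matches the new shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False), norm(out_ch)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        return torch.relu(out + self.shortcut(x))


block = BasicBlock(64, 128, stride=2)
y = block(torch.randn(2, 64, 32, 32))
print(y.shape)  # torch.Size([2, 128, 16, 16])

batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
has_batchnorm = any(isinstance(m, batchnorm_types) for m in block.modules())
print(has_batchnorm)  # False
```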
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_channels | int | Number of input channels (1 for MNIST, 3 for CIFAR). | 3 |
| num_classes | int | Number of output classes. | 10 |
| feature_dim | int | Width of the penultimate layer (default 512). | 512 |