Using custom models, transforms, augmentations, and losses
KonfAI custom objects are regular Python classes selected from YAML through
classpath and instantiated through apply_config().
To integrate cleanly, a custom object must satisfy two contracts:
- the configuration contract: constructor argument names must match the YAML keys
- the runtime contract: the class must inherit from the right KonfAI base class and implement the expected methods
This page focuses on the custom object types that matter most in practice:
- models
- transforms
- augmentations
- losses and metrics
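The configuration contract can be made concrete with a small, self-contained sketch of "select a class by `classpath`, then feed it the YAML keys as constructor arguments". This is not KonfAI's `apply_config()` implementation. `resolve_classpath`, `instantiate`, and the in-memory `MyLoss` module are hypothetical stand-ins, and the sketch raises on unknown keys only to make a key/argument mismatch obvious. KonfAI's own handling of extra keys may differ.

```python
from __future__ import annotations

import importlib
import inspect
import sys
import types
from typing import Any


def resolve_classpath(classpath: str) -> type:
    """Split "Module:ClassName", import the module, and return the class."""
    module_name, _, class_name = classpath.partition(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def instantiate(classpath: str, branch: dict[str, Any]) -> Any:
    """Build an object whose constructor arguments are read from a YAML-like dict."""
    cls = resolve_classpath(classpath)
    params = inspect.signature(cls).parameters
    accepts_any = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    # Report YAML keys that no constructor argument would consume.
    unknown = [] if accepts_any else sorted(set(branch) - set(params))
    if unknown:
        raise TypeError(f"{cls.__name__} has no constructor argument(s) {unknown}")
    return cls(**branch)


# Simulate a local MyLoss.py file sitting next to the YAML (hypothetical).
class BoundaryDice:
    def __init__(self, weight: float = 1.0, smooth: float = 1e-5) -> None:
        self.weight = weight
        self.smooth = smooth


my_loss_module = types.ModuleType("MyLoss")
my_loss_module.BoundaryDice = BoundaryDice
sys.modules["MyLoss"] = my_loss_module
```

With this setup, `instantiate("MyLoss:BoundaryDice", {"weight": 0.5})` builds the object, while a key spelled `Weight` fails because no constructor argument has that exact name.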
General rules
- Keep project-specific Python files next to the YAML when possible.
- Use explicit local classpaths such as `Model:UNetpp5` or `MyLoss:BoundaryDice`.
- Keep constructor argument names in snake_case.
- Do not use `@config()` for local custom classes by default. It inserts an extra YAML subtree named after the class, which makes custom configs harder to read and generate.
- Use `@config("...")` only when you intentionally want a fixed, explicit subtree name.
- Prefer concrete defaults for nested custom objects when they define the natural baseline behavior of the class. This makes the default wiring visible to KonfAI and keeps YAML generation explicit.
- Start from a shipped example and change one layer at a time.
For example, this pattern is usually a good fit:
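```python
from konfai.network import network

# UNetpp5 and Discriminator are local classes assumed to be defined
# earlier in the same project file (for example Model.py).


class Gan(network.Network):
    def __init__(
        self,
        generator: UNetpp5 = UNetpp5(),
        discriminator: Discriminator = Discriminator(),
    ) -> None:
        super().__init__()
        # Register both sub-networks as named nodes of the KonfAI graph.
        self.add_module("Generator", generator)
        self.add_module("Discriminator", discriminator)
```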
This style is often preferable in KonfAI because the constructor already shows the effective default model or loss stack that will appear in YAML.
Use None defaults only when the nested object must be created dynamically,
depends on runtime information, or would be too expensive or stateful to
instantiate eagerly.
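The sketch below shows why a concrete default keeps the nested wiring visible while a `None` default hides it. It assumes that a config generator discovers defaults by inspecting constructor signatures. That mechanism is only an illustration, not KonfAI's own, and `UNetpp5`, `Discriminator`, and `default_config` here are hypothetical stand-ins.

```python
from __future__ import annotations

import inspect
from typing import Any


class UNetpp5:  # hypothetical stand-in for a local generator
    def __init__(self, channels: int = 32, depth: int = 5) -> None:
        self.channels = channels
        self.depth = depth


class Discriminator:  # hypothetical stand-in for a local discriminator
    def __init__(self, channels: int = 16) -> None:
        self.channels = channels


class Gan:
    def __init__(
        self,
        generator: UNetpp5 = UNetpp5(),  # concrete default: visible in the signature
        discriminator: Discriminator | None = None,  # None default: hidden
    ) -> None:
        self.generator = generator
        # The real default is only decided here, at construction time.
        self.discriminator = discriminator if discriminator is not None else Discriminator()


def default_config(cls: type) -> dict[str, Any]:
    """Hypothetical generator: turn constructor defaults into a nested dict."""
    tree: dict[str, Any] = {}
    for name, param in inspect.signature(cls).parameters.items():
        default = param.default
        if default is inspect.Parameter.empty:
            continue  # required argument: nothing to show
        if default is None or isinstance(default, (bool, int, float, str)):
            tree[name] = default
        else:
            # A concrete nested object: expose its own class defaults as a subtree.
            tree[name] = default_config(type(default))
    return tree
```

Here `default_config(Gan)` can describe the generator subtree but reports the discriminator only as `None`, even though `Gan()` does build one. Python evaluates a default such as `UNetpp5()` once, when the `def` statement runs, so every `Gan()` created in plain Python without an explicit generator shares that instance. Whether this matters in KonfAI depends on how `apply_config()` builds nested objects, which this sketch does not model.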
How classpath and @config(...) interact
classpath selects the Python implementation.
Example:
```yaml
Trainer:
  Model:
    classpath: Model:UNetpp5
```
Without @config(...), a local class loaded through classpath reads its
constructor arguments directly from the current YAML branch.
For local custom classes, this is usually what you want:
```yaml
Trainer:
  Model:
    classpath: Model:UNetpp5
    outputs_criterions:
      ...
```
@config(...) is only needed when you deliberately want an extra nested
subtree.
In the current codebase:
- `@config("SomeKey")` binds the class to the `SomeKey` subtree.
- `@config()` defaults to the class name.
That means @config() on a local class loaded as Model:UNetpp5 would force a
less convenient YAML shape such as:
```yaml
Trainer:
  Model:
    classpath: Model:UNetpp5
    UNetpp5:
      ...
```
Avoid that implicit nesting unless you explicitly need it.
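A toy model may make the rule easier to see. An undecorated class reads the branch it sits in, while `@config()` or `@config("SomeKey")` makes it read one level deeper. The code below is a hypothetical re-implementation for illustration only. `toy_config` stands in for KonfAI's real `@config` decorator, `select_branch` is an invented helper, and `depth` is an invented constructor argument.

```python
from __future__ import annotations

from typing import Any, Callable


def toy_config(key: str | None = None) -> Callable[[type], type]:
    """Hypothetical stand-in for @config: record which subtree a class reads."""
    def wrap(cls: type) -> type:
        cls._config_key = key if key is not None else cls.__name__
        return cls
    return wrap


def select_branch(cls: type, branch: dict[str, Any]) -> dict[str, Any]:
    """Undecorated classes read `branch` itself; decorated ones read one level deeper."""
    key = getattr(cls, "_config_key", None)
    return branch if key is None else branch[key]


class UNetpp5:  # plain local class, no decorator
    pass


@toy_config()  # subtree named after the class
class UNetpp5Nested:
    pass


@toy_config("SomeKey")  # fixed explicit subtree name
class UNetpp5Keyed:
    pass


# The YAML branch under Trainer -> Model, written as a dict.
model_branch = {
    "classpath": "Model:UNetpp5",
    "depth": 5,
    "UNetpp5Nested": {"depth": 4},
    "SomeKey": {"depth": 3},
}
```

Editing `depth: 5` at the parent level affects only the undecorated class. The decorated variants keep reading their own subtrees, which is the "edited at the parent level" failure mode described at the end of this page.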
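Contract summary

| Custom object | Recommended base class | Required methods |
|---|---|---|
| Model | `konfai.network.network.Network` | `super().__init__(...)` and `add_module(...)` in the constructor |
| Transform | `konfai.data.transform.Transform` or `TransformInverse` | `__call__(name, tensor, cache_attribute)`; `transform_shape(...)` if the shape changes; `inverse(...)` for `TransformInverse` |
| Augmentation | `konfai.data.augmentation.DataAugmentation` | `_state_init(...)`, `_compute(...)`, `_inverse(...)` |
| Loss / metric | `konfai.metric.measure.Criterion` (or `CriterionWithInit`, `CriterionWithAttribute`) | `forward(output, *targets)` |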
Custom models
For a full KonfAI model, inherit from
`konfai.network.network.Network`.
This is the right choice when you need:
- a named graph built with `add_module(...)`
- patch-aware inference
- multiple outputs, each with its own `outputs_criterions` entry
- full compatibility with KonfAI training, prediction, and evaluation workflows
Minimal example:
```python
import torch

from konfai.data.patching import ModelPatch
from konfai.network import blocks, network


class MySegNet(network.Network):
    def __init__(
        self,
        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
        schedulers: dict[str, network.LRSchedulersLoader] = {
            "default|ReduceLROnPlateau": network.LRSchedulersLoader(0)
        },
        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {
            "Argmax": network.TargetCriterionsLoader()
        },
        patch: ModelPatch | None = None,
        dim: int = 2,
    ) -> None:
        # Let KonfAI attach the optimizer, schedulers, patching, and criteria.
        super().__init__(
            in_channels=1,
            optimizer=optimizer,
            schedulers=schedulers,
            outputs_criterions=outputs_criterions,
            patch=patch,
            dim=dim,
        )
        # Each add_module name becomes a graph node; "Argmax" is the output
        # path referenced by outputs_criterions.
        self.add_module("Backbone", torch.nn.Conv2d(1, 8, kernel_size=3, padding=1))
        self.add_module("Head", torch.nn.Conv2d(8, 2, kernel_size=1))
        self.add_module("Argmax", blocks.ArgMax(dim=1))
```
Matching YAML:
```yaml
Trainer:
  Model:
    classpath: Model:MySegNet
    outputs_criterions:
      Argmax:
        targets_criterions:
          SEG:
            criterions_loader:
              Dice:
                is_loss: true
                group: 0
                schedulers:
                  Constant:
                    nb_step: 0
                    value: 1
```
Model contract details
- Call `super().__init__(...)` so KonfAI can attach the optimizer, scheduler, patching, and criterion machinery.
- Build the graph with `self.add_module(...)`, not only with raw PyTorch attributes.
- The keys in `outputs_criterions` must match the actual model output path, built from `add_module(...)` names.
- The output path is a graph name, not a dataset name. For example, in the built-in UNet config the key is `UNetBlock_0:Head:Argmax`.
KonfAI can wrap a simpler module internally in some situations, but if you want
reliable custom behavior, inheriting from Network is the supported path.
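The `outputs_criterions` keys are colon-joined graph names, so a quick pre-flight check can catch a key that points to a graph name the model does not have. The sketch below assumes that an output path is simply the chain of `add_module(...)` names joined with `:`, as in `UNetBlock_0:Head:Argmax`. It represents the graph as a plain nested dict, and `output_paths` and `check_outputs_criterions` are hypothetical helpers, not KonfAI APIs. On a real PyTorch module, `named_modules()` yields the same names joined with `.` instead.

```python
from __future__ import annotations


def output_paths(graph: dict[str, dict], prefix: str = "") -> list[str]:
    """List every node path in a nested name tree, joined with ':'."""
    paths: list[str] = []
    for name, children in graph.items():
        path = f"{prefix}:{name}" if prefix else name
        paths.append(path)
        paths.extend(output_paths(children, path))
    return paths


def check_outputs_criterions(graph: dict[str, dict], outputs_criterions: dict) -> list[str]:
    """Return the outputs_criterions keys that match no node path."""
    known = set(output_paths(graph))
    return [key for key in outputs_criterions if key not in known]


# Hypothetical name tree mirroring nested add_module(...) calls in a UNet-like model.
unet_graph = {"UNetBlock_0": {"Backbone": {}, "Head": {"Argmax": {}}}}
```

For this tree, the key `UNetBlock_0:Head:Argmax` passes, while a bare `Argmax` or a dataset name such as `SEG` is reported as unknown.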
Custom transforms
Use `konfai.data.transform.Transform` for one-way transforms and
`TransformInverse` when KonfAI must be able to invert the operation later.
The key methods are:
- `__call__(name, tensor, cache_attribute)` to transform the tensor
- `transform_shape(...)` if the transform changes the tensor shape
- `inverse(...)` if you inherit from `TransformInverse`
cache_attribute is where you should save anything needed later by the inverse
transform.
Minimal `TransformInverse` subclass:
```python
import torch

from konfai.data.transform import TransformInverse
from konfai.utils.dataset import Attribute


class Clamp01(TransformInverse):
    def __init__(self, inverse: bool = False) -> None:
        super().__init__(inverse)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Store per-case values in cache_attribute so inverse() can reach them later.
        cache_attribute["original_min"] = tensor.min()
        cache_attribute["original_max"] = tensor.max()
        return tensor.clamp(0, 1)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Clamping discards out-of-range values, so this inverse is a pass-through;
        # the cached min/max only show where inverse data belongs.
        return tensor
```
Matching YAML:
```yaml
groups_dest:
  CT:
    transforms:
      MyTransforms:Clamp01:
        inverse: false
```
Transform contract details
- The transform receives one tensor at a time.
- `name` is the case identifier.
- `cache_attribute` stores per-case metadata and is the right place for values needed by `inverse(...)`.
- `self.datasets` is populated by KonfAI and can be used when the transform needs to read another group, as built-in transforms such as `Clip` and `Standardize` do.
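Clamping discards information, but min-max scaling can be undone exactly if the original range is cached per case. The sketch below mirrors the `__call__` / `inverse` pair and the role of `cache_attribute` in plain Python. A list of floats stands in for the tensor and a dict stands in for KonfAI's `Attribute`, so this illustrates the contract under those assumptions and is not a drop-in `TransformInverse` subclass. `MinMaxNormalize` is a hypothetical name.

```python
from __future__ import annotations


class MinMaxNormalize:
    """Toy invertible transform: map values to [0, 1] and restore them later."""

    def __call__(self, name: str, values: list[float], cache: dict[str, float]) -> list[float]:
        lo, hi = min(values), max(values)
        # Save exactly what inverse() needs to undo this case.
        cache["original_min"] = lo
        cache["original_max"] = hi
        span = (hi - lo) or 1.0  # avoid dividing by zero on constant input
        return [(v - lo) / span for v in values]

    def inverse(self, name: str, values: list[float], cache: dict[str, float]) -> list[float]:
        lo, hi = cache["original_min"], cache["original_max"]
        span = (hi - lo) or 1.0
        return [v * span + lo for v in values]
```

Because the range lives in the per-case cache rather than on `self`, one transform instance can process many cases and still invert each case correctly.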
Custom augmentations
Use `konfai.data.augmentation.DataAugmentation` for training-time
augmentations and test-time augmentation building blocks.
The augmentation contract has three stages:
- `_state_init(...)` samples randomness once and can update the expected shapes
- `_compute(...)` applies the augmentation to the selected tensors
- `_inverse(...)` inverts the augmentation when needed
Minimal example:
```python
import torch

from konfai.data.augmentation import DataAugmentation
from konfai.utils.dataset import Attribute


class AddNoise(DataAugmentation):
    def __init__(self, sigma: float = 0.1, groups: list[str] | None = None) -> None:
        super().__init__(groups=groups)
        self.sigma = sigma

    def _state_init(
        self,
        index: int,
        shapes: list[list[int]],
        caches_attribute: list[Attribute],
    ) -> list[list[int]]:
        # Additive noise does not change shapes and needs no shared random state.
        return shapes

    def _compute(self, name: str, index: int, tensors: list[torch.Tensor]) -> list[torch.Tensor]:
        # Fresh Gaussian noise for each selected tensor, scaled by sigma.
        return [tensor + torch.randn_like(tensor.float()) * self.sigma for tensor in tensors]

    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
        # The sampled noise is not stored, so the inverse is a no-op.
        return tensor
```
Matching YAML:
```yaml
augmentations:
  DataAugmentation_0:
    data_augmentations:
      MyAugmentations:AddNoise:
        sigma: 0.1
        prob: 0.5
    nb: 1
```
Augmentation contract details
- Augmentations operate on a list of tensors, not on a single tensor.
- `prob` lives in the same YAML branch as the augmentation-specific parameters.
- `_state_init(...)` is the right place to sample random state that must stay consistent across groups.
- Implement `_inverse(...)` even for a no-op, because KonfAI may call it during inverse augmentation workflows.
Custom losses and metrics
For custom criteria, use one of these base classes:
- `Criterion` for the common case
- `CriterionWithInit` when the criterion needs access to the model graph before training starts
- `CriterionWithAttribute` when the criterion needs per-sample attributes
Minimal loss:
```python
import torch

from konfai.metric.measure import Criterion


class BoundaryMAE(Criterion):
    def __init__(self, weight: float = 1.0) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
        # Weighted mean absolute error against the first target.
        return self.weight * torch.nn.functional.l1_loss(output.float(), targets[0].float())
```
Matching YAML:
```yaml
outputs_criterions:
  Argmax:
    targets_criterions:
      SEG:
        criterions_loader:
          MyLosses:BoundaryMAE:
            is_loss: true
            group: 0
            schedulers:
              Constant:
                nb_step: 0
                value: 1
            weight: 1.0
```
Loss contract details
- `forward(output, *targets)` is the standard signature.
- `CriterionWithAttribute` uses `forward(output, *targets, attributes=...)`.
- `CriterionWithInit` adds `init(model, output_group, target_group)`.
- A criterion can return either a tensor or a tuple such as `(loss_tensor, scalar_value_for_logging)`.
- The YAML branch lives under `outputs_criterions -> <output_group> -> targets_criterions -> <target_group> -> criterions_loader`.
- The same YAML branch can contain both KonfAI runtime fields such as `is_loss`, `group`, and `schedulers`, and the constructor arguments of the criterion itself. Each object reads only the keys it needs.
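One way to picture "each object reads only the keys it needs" is to split one criterion branch into runtime fields and constructor arguments. This is a hypothetical sketch, not KonfAI's loader. `RUNTIME_KEYS` lists only the runtime fields named on this page, and `BoundaryMAE` here is a dependency-free stand-in with the same constructor as the example above.

```python
from __future__ import annotations

import inspect
from typing import Any

# Runtime fields named on this page; the full set KonfAI reads may differ.
RUNTIME_KEYS = {"is_loss", "group", "schedulers"}


class BoundaryMAE:  # stand-in with the same constructor as the example above
    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight


def split_branch(cls: type, branch: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Separate runtime fields from the constructor arguments of one criterion branch."""
    ctor_names = set(inspect.signature(cls).parameters)
    runtime = {k: v for k, v in branch.items() if k in RUNTIME_KEYS}
    ctor_args = {k: v for k, v in branch.items() if k in ctor_names}
    return runtime, ctor_args


# The MyLosses:BoundaryMAE branch from the YAML above, written as a dict.
criterion_branch = {
    "is_loss": True,
    "group": 0,
    "schedulers": {"Constant": {"nb_step": 0, "value": 1}},
    "weight": 1.0,
}
```

If the two sets overlap, the split stops being clean. For example, a criterion with a constructor argument named `group` would receive the same value as the runtime field, so it is safer to avoid reusing those names.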
Common failure modes
- `classpath` imports the wrong file or class.
- Constructor argument names do not match YAML keys.
- The object inherits from the wrong base class.
- `outputs_criterions` points to a graph name that does not exist in the model.
- `@config()` inserted an extra class-name subtree, but the YAML was edited at the parent level.
- `@config("...")` points to a different subtree than the one you edited in YAML.
See also
- doc: `../concepts/configuration`
- doc: `../concepts/model-graph`
- doc: `../reference/api/extension-points`
- doc: `../examples/synthesis`