Workflows API

The low-level KonfAI workflows are exposed through three entrypoint functions and three root classes:

  • training

  • prediction

  • evaluation

These are the main Python APIs behind the konfai CLI.

Training

konfai.trainer.train(command=State.TRAIN, overwrite=False, model=None, gpu=[], cpu=None, quiet=False, tensorboard=False, config=PosixPath('Config.yml'), checkpoints_dir=PosixPath('Checkpoints'), statistics_dir=PosixPath('Statistics'))

Build and execute the configured training workflow.

This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_train().

Return type:

DistributedObject
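As a rough sketch, the entrypoint can be driven from Python with the keyword arguments listed in the signature above. The helper names default_train_kwargs and run_training are hypothetical and not part of KonfAI; the values only mirror the documented defaults, and the import is deferred so the snippet loads even where KonfAI is not installed.

```python
from pathlib import Path
from typing import Any


def default_train_kwargs() -> dict[str, Any]:
    """Keyword arguments mirroring the documented defaults of konfai.trainer.train."""
    return {
        "overwrite": False,
        "gpu": [],
        "cpu": None,
        "quiet": False,
        "tensorboard": False,
        "config": Path("Config.yml"),
        "checkpoints_dir": Path("Checkpoints"),
        "statistics_dir": Path("Statistics"),
    }


def run_training(**overrides: Any):
    """Hypothetical helper: merge overrides into the defaults and call train()."""
    # Deferred import: KonfAI is only needed when training is actually launched.
    from konfai.trainer import train

    kwargs = {**default_train_kwargs(), **overrides}
    return train(**kwargs)
```

Under these assumptions, run_training(overwrite=True) would launch the workflow described by Config.yml in the current working directory.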

class konfai.trainer.Trainer(model=<konfai.network.network.ModelLoader object>, dataset={'dataset_filenames': ['default|./Dataset:mha'], 'groups_src': {'default|Labels': {'default|Labels': {'transforms': [], 'patch_transforms': []}}}, 'patch': <konfai.data.patching.DatasetPatch object>, 'use_cache': True, 'subset': <konfai.data.data_manager.TrainSubset object>, 'batch_size': 1, 'validation': 0.2, 'inline_augmentations': False, 'data_augmentations_list': {'DataAugmentation_0': <konfai.data.augmentation.DataAugmentationsList object>}}, train_name='default|TRAIN_01', manual_seed=None, epochs=100, it_validation=None, it_lr_update=None, autocast=False, gradient_checkpoints=None, gpu_checkpoints=None, ema_decay=0, data_log=None, early_stopping=None, save_checkpoint_mode='BEST')

Bases: DistributedObject

Public API for training a model with the KonfAI framework. Wraps setup, checkpointing, resuming, logging, and the launch of the distributed _Trainer.

Main responsibilities:
  • Initialization from config (via @config)

  • Model and EMA setup

  • Checkpoint loading and saving

  • Distributed setup and launch

Parameters:
  • model (ModelLoader) – Loader for model architecture.

  • dataset (DataTrain) – Training/validation dataset.

  • train_name (str) – Training session name.

  • manual_seed (int | None) – Random seed.

  • epochs (int) – Number of epochs to run.

  • it_validation (int | None) – Validation interval.

  • it_lr_update (int | None) – Learning rate update interval.

  • autocast (bool) – Enable AMP training.

  • gradient_checkpoints (list[str] | None) – Modules to use gradient checkpointing on.

  • gpu_checkpoints (list[str] | None) – Modules to pin on specific GPUs.

  • ema_decay (float) – EMA decay factor.

  • data_log (list[str] | None) – Logging instructions.

  • early_stopping (EarlyStopping | None) – Optional early stopping config.

  • save_checkpoint_mode (str) – Either “BEST” or “ALL”.
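The ema_decay parameter is described only as an EMA decay factor. For orientation, the conventional exponential-moving-average update is sketched below on plain floats. Whether KonfAI applies exactly this rule, and how it treats the default of 0, is an assumption; the function name ema_update is hypothetical.

```python
def ema_update(ema_weights: list[float], weights: list[float], decay: float) -> list[float]:
    """Conventional EMA step: ema <- decay * ema + (1 - decay) * current."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must lie in [0, 1]")
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# With decay=0 the average simply tracks the current weights.
shadow = ema_update([1.0, 2.0], [0.0, 4.0], decay=0.0)
# With a decay close to 1 the average moves slowly toward the current weights.
slow = ema_update([1.0, 2.0], [0.0, 4.0], decay=0.9)
```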

setup(world_size)

Initializes the training environment:
  • Clears previous outputs (unless resuming)

  • Initializes model and EMA

  • Loads checkpoint (if resuming)

  • Prepares dataloaders

Parameters:

world_size (int) – Total number of distributed processes.

run_process(world_size, global_rank, local_rank, dataloaders)

Launches the actual training process via the internal _Trainer class. Wraps the model with DDP or a CPU fallback, attaches EMA, and starts training.

Parameters:
  • world_size (int) – Total number of distributed processes.

  • global_rank (int) – Global rank of the current process.

  • local_rank (int) – Local rank within the node.

  • dataloaders (list[DataLoader]) – Training and validation dataloaders.

class konfai.trainer.EarlyStopping(monitor=None, patience=10, min_delta=0.0, mode='min')

Bases: EarlyStoppingBase

Implements early stopping logic with configurable patience and monitored metrics.

monitor

Metrics to monitor.

Type:

list[str]

patience

Number of checks with no improvement before stopping.

Type:

int

min_delta

Minimum change to qualify as improvement.

Type:

float

mode

“min” or “max” depending on optimization direction.

Type:

str
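The interplay of patience, min_delta, and mode can be illustrated with a small stand-alone tracker. This is a minimal sketch of the usual early-stopping rule, not KonfAI's implementation: the class name PatienceTracker and its step method are hypothetical, and it follows a single metric rather than a list of monitored ones.

```python
from typing import Optional


class PatienceTracker:
    """Hypothetical stand-in illustrating patience/min_delta/mode."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best: Optional[float] = None
        self.bad_checks = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one check and return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


tracker = PatienceTracker(patience=2, mode="min")
decisions = [tracker.step(v) for v in (1.0, 0.9, 0.95, 0.96)]
```

Here the loss improves twice and then fails to improve on two consecutive checks, so only the last call signals a stop.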

Prediction

konfai.predictor.predict(models, overwrite=False, gpu=[], cpu=1, quiet=False, tb=False, prediction_file=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Prediction.yml'), predictions_dir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Predictions'))

Build and execute the configured prediction workflow.

This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_predict().

Return type:

DistributedObject

class konfai.predictor.Predictor(model=<konfai.network.network.ModelLoader object>, dataset={'dataset_filenames': ['default|./Dataset'], 'groups_src': {'default': {'default|Labels': {'transforms': [], 'patch_transforms': []}}}, 'patch': <konfai.data.patching.DatasetPatch object>, 'use_cache': False, 'subset': <konfai.data.data_manager.PredictionSubset object>, 'batch_size': 1, 'validation': None, 'inline_augmentations': False, 'data_augmentations_list': {'DataAugmentation_0': <konfai.data.augmentation.DataAugmentationsList object>}}, combine='Mean', train_name='name', manual_seed=None, gpu_checkpoints=None, autocast=False, outputs_dataset={'default|Default': <konfai.predictor.OutputDatasetLoader object>}, data_log=None)

Bases: DistributedObject

KonfAI’s main prediction controller.

This class orchestrates the prediction phase by:
  • Loading model weights from checkpoint(s) or URL(s)

  • Preparing datasets and output configurations

  • Managing distributed inference with optional multi-GPU support

  • Applying transformations and saving predictions

  • Optionally logging results to TensorBoard

model

The neural network model to use for prediction.

Type:

Network

dataset

Dataset manager for prediction data.

Type:

DataPrediction

combine_classpath

Path to the reduction strategy (e.g., “Mean”).

Type:

str

autocast

Whether to enable AMP inference.

Type:

bool

outputs_dataset

Mapping from layer names to output writers.

Type:

dict[str, OutputDataset]

data_log

List of tensors to log during inference.

Type:

list[str] | None

setup(world_size)

Set up the predictor for inference.

This method performs all necessary initialization steps before running predictions:
  • Ensures output directories exist, and optionally prompts the user before overwriting existing predictions.

  • Copies the current configuration file (Prediction.yml) into the output directory for reproducibility.

  • Dynamically loads pretrained weights from local files or remote URLs.

  • Wraps the base model into a ModelComposite to support ensemble inference.

  • Initializes the prediction dataloader, with proper distribution across available GPUs.

Parameters:

world_size (int) – Total number of processes or GPUs used for distributed prediction.
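The first two setup steps (guarding existing outputs and copying the configuration for reproducibility) could look roughly like the sketch below. The helper prepare_output_dir is hypothetical, and it raises an error instead of prompting the user, which is a simplification of the behaviour described above.

```python
import shutil
import tempfile
from pathlib import Path


def prepare_output_dir(config_file: Path, output_dir: Path, overwrite: bool = False) -> Path:
    """Hypothetical helper: create output_dir and copy the config into it."""
    if output_dir.exists() and any(output_dir.iterdir()) and not overwrite:
        raise FileExistsError(f"{output_dir} already contains results")
    output_dir.mkdir(parents=True, exist_ok=True)
    # Keeping a copy of the config next to the outputs records how they were produced.
    return Path(shutil.copy2(config_file, output_dir / config_file.name))


with tempfile.TemporaryDirectory() as tmp:
    config = Path(tmp) / "Prediction.yml"
    config.write_text("# placeholder config\n")
    copied = prepare_output_dir(config, Path(tmp) / "Predictions" / "run")
    copied_exists = copied.is_file()
    copied_name = copied.name
```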

run_process(world_size, global_rank, local_rank, dataloaders)

Launch prediction on the given process rank.

Parameters:
  • world_size (int) – Total number of processes.

  • global_rank (int) – Rank of the current process.

  • local_rank (int) – Local device rank.

  • dataloaders (list[DataLoader]) – List of data loaders for prediction.

class konfai.predictor.OutputDataset(filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patch_combine, reduction)

Bases: Dataset, NeedDevice, ABC

Abstract prediction sink that accumulates model outputs and writes them to disk.

Concrete subclasses define how layers are accumulated across patches, augmentations, and multiple models before the final prediction volume is materialized.
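To make the idea of accumulating across multiple models concrete, here is a minimal sketch of a mean reduction over per-model outputs, in the spirit of the default combine='Mean'. It works on plain Python lists instead of tensors, and the function name mean_reduce is hypothetical; the actual reduction classes may behave differently.

```python
def mean_reduce(predictions: list[list[float]]) -> list[float]:
    """Element-wise mean over predictions of equal length (one list per model)."""
    if not predictions:
        raise ValueError("at least one prediction is required")
    length = len(predictions[0])
    if any(len(p) != length for p in predictions):
        raise ValueError("all predictions must have the same length")
    return [sum(values) / len(predictions) for values in zip(*predictions)]


# Three models voting on two output values.
ensemble = mean_reduce([[0.2, 0.8], [0.4, 0.6], [0.6, 0.4]])
```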

class konfai.predictor.OutputDatasetLoader(name_class='OutSameAsGroupDataset')

Bases: object

Factory that instantiates output dataset classes from predictor config.

Evaluation

konfai.evaluator.evaluate(overwrite=False, gpu=[], cpu=1, quiet=False, tb=False, evaluations_file=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Evaluation.yml'), evaluations_dir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Evaluations'))

Build and execute the configured evaluation workflow.

This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_evaluate().

Return type:

DistributedObject

class konfai.evaluator.Evaluator(train_name='default|TRAIN_01', metrics={'default': <konfai.evaluator.TargetCriterionsLoader object>}, dataset={'dataset_filenames': ['default|./Dataset:mha'], 'groups_src': {'default': {'default|group_dest': {'transforms': [], 'patch_transforms': []}}}, 'patch': None, 'use_cache': True, 'subset': <konfai.data.data_manager.PredictionSubset object>, 'batch_size': 1, 'validation': None, 'inline_augmentations': False, 'data_augmentations_list': {}})

Bases: DistributedObject

Distributed evaluation engine for computing metrics on model predictions.

This class handles the evaluation of predicted outputs using predefined metric loaders. It supports multi-output and multi-target configurations, computes aggregated statistics across training and validation datasets, and synchronizes results across processes.

Evaluation results are stored in JSON format and optionally displayed during iteration.

Parameters:
  • train_name (str) – Unique name of the evaluation run, used for logging and output folders.

  • metrics (dict[str, TargetCriterionsLoader]) – Dictionary mapping output groups to loaders of target metrics.

  • dataset (DataMetric) – Dataset provider configured for evaluation mode.

statistics_train

Object used to store training evaluation metrics.

Type:

Statistics

statistics_validation

Object used to store validation evaluation metrics.

Type:

Statistics

dataloader

DataLoaders for training and validation sets.

Type:

list[DataLoader]

metric_path

Path to the evaluation output directory.

Type:

str

metrics

Instantiated metrics organized by output and target groups.

Type:

dict

setup(world_size)

Prepare the evaluator for distributed metric computation.

This method performs the following steps:
  • Checks whether previous evaluation results exist and optionally overwrites them.

  • Creates the output directory and copies the current configuration file for reproducibility.

  • Loads the evaluation dataset according to the world size.

Parameters:

world_size (int) – Number of processes in the distributed evaluation setup.

update(batch_sample, statistics)

Compute metrics for a batch and update running statistics.

Parameters:
  • batch_sample (dict[str, BatchDataItem]) – The batch sample object containing tensors and their metadata.

  • statistics (Statistics) – The statistics object to update (train or validation).

Returns:

Dictionary of computed metric values, with keys in the format 'output_group:target_group:MetricName'.

Return type:

dict[str, float]
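The key format of the returned dictionary can be sketched with two small helpers. Both function names and the example group and metric names (Prediction, Labels, Dice) are hypothetical, and the split helper assumes that none of the three parts contains a colon.

```python
def metric_key(output_group: str, target_group: str, metric_name: str) -> str:
    """Build a key of the form 'output_group:target_group:MetricName'."""
    return f"{output_group}:{target_group}:{metric_name}"


def split_metric_key(key: str) -> tuple[str, str, str]:
    """Split a key back into its three parts (assumes no ':' inside a part)."""
    output_group, target_group, metric_name = key.split(":")
    return output_group, target_group, metric_name


# A dictionary shaped like the documented return value, with made-up numbers.
batch_values = {metric_key("Prediction", "Labels", "Dice"): 0.91}
parts = split_metric_key("Prediction:Labels:Dice")
```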

run_process(world_size, global_rank, gpu, dataloaders)

Execute the distributed evaluation loop over the training and validation datasets.

This method iterates through the provided DataLoaders (train and optionally validation), updates the metric statistics using the configured metrics dictionary, and synchronizes the results across all processes. On the global rank 0, the metrics are saved as JSON files.

Metrics are displayed in real-time using tqdm progress bars, showing a summary of the current batch’s computed values.

Parameters:
  • world_size (int) – Total number of distributed processes.

  • global_rank (int) – Global rank of the current process (used for writing results).

  • gpu (int) – Local GPU ID used for synchronization.

  • dataloaders (list[DataLoader]) – A list containing one or two DataLoaders: dataloaders[0] is used for training evaluation, and the optional dataloaders[1] is used for validation evaluation.

Notes

  • Only the main process (global_rank == 0) writes final results to disk.
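Two conventions from this method lend themselves to a short sketch: the one-or-two-element dataloaders list, and the rule that only global rank 0 writes results. The helpers name_loaders and write_if_main and the file name metrics.json are hypothetical, and the synchronization across processes is left out entirely.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def name_loaders(dataloaders: list[Any]) -> dict[str, Any]:
    """Map a one- or two-element loader list to 'train' and optional 'validation'."""
    if len(dataloaders) not in (1, 2):
        raise ValueError("expected one or two dataloaders")
    return dict(zip(("train", "validation"), dataloaders))


def write_if_main(global_rank: int, path: Path, results: dict[str, Any]) -> bool:
    """Write results as JSON on rank 0 only; return whether a file was written."""
    if global_rank != 0:
        return False
    path.write_text(json.dumps(results, indent=2))
    return True


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "metrics.json"
    written_by_rank1 = write_if_main(1, target, {"Dice": 0.9})
    exists_after_rank1 = target.exists()
    written_by_rank0 = write_if_main(0, target, {"Dice": 0.9})
    stored = json.loads(target.read_text())
```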

class konfai.evaluator.Statistics(filename)

Bases: object

Utility class to accumulate, structure, and write evaluation metric results.

This class is used to:
  • Collect metrics for each dataset sample.

  • Compute aggregate statistics (mean, std, percentiles, etc.).

  • Export all results in a structured JSON format, including both per-case and aggregate values.

Parameters:

filename (Path) – Path to the output JSON file that will store the final results.

add(values, name_dataset)

Add a set of metric values for a given dataset case.

Parameters:
  • values (dict[str, float]) – Dictionary of metric names and their values.

  • name_dataset (str) – Identifier (e.g., case name) for the sample.

Return type:

None

static get_statistic(values)

Compute statistical aggregates for a list of metric values.

Parameters:

values (list[float]) – Values to summarize.

Returns:

A dictionary containing:
  • max, min, std

  • 25th, 50th, and 75th percentiles

  • mean and count

Return type:

dict[str, float]
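A summary of this kind can be approximated with the standard library alone. The sketch below is not KonfAI's code: the function name summarize and the key names are hypothetical, the standard deviation is assumed to be the population form, and the percentiles use linear interpolation between data points.

```python
import statistics


def summarize(values: list[float]) -> dict[str, float]:
    """Hypothetical aggregate summary; key names and std convention are assumptions."""
    if not values:
        raise ValueError("values must not be empty")
    if len(values) > 1:
        q25, q50, q75 = statistics.quantiles(values, n=4, method="inclusive")
    else:
        # A single value is its own percentile at every level.
        q25 = q50 = q75 = values[0]
    return {
        "max": max(values),
        "min": min(values),
        "std": statistics.pstdev(values),  # population standard deviation
        "25pc": q25,
        "50pc": q50,
        "75pc": q75,
        "mean": statistics.fmean(values),
        "count": float(len(values)),
    }


summary = summarize([1.0, 2.0, 3.0, 4.0, 5.0])
```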

write(outputs)

Write the collected and aggregated statistics to the configured output file.

The output JSON structure contains:
  • case: All individual metrics per sample.

  • aggregates: Global statistics computed over all cases.

Parameters:

outputs (list[dict[str, dict[str, Any]]]) – List of metric dictionaries to merge and serialize.

Return type:

None
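The two top-level keys of the output file can be pictured with the following sketch. Only the names case and aggregates come from the description above; the nesting beneath them, the helper build_report, and the case and metric names are assumptions made for illustration.

```python
import json
import statistics
from typing import Any


def build_report(cases: dict[str, dict[str, float]]) -> dict[str, Any]:
    """Assemble a report with per-case values and per-metric aggregates."""
    per_metric: dict[str, list[float]] = {}
    for metrics in cases.values():
        for name, value in metrics.items():
            per_metric.setdefault(name, []).append(value)
    aggregates = {
        name: {"mean": statistics.fmean(vals), "count": len(vals)}
        for name, vals in per_metric.items()
    }
    return {"case": cases, "aggregates": aggregates}


report = build_report({
    "case_001": {"Dice": 0.5},
    "case_002": {"Dice": 1.0},
})
serialized = json.dumps(report, indent=2)
```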

See also