Workflows API
The low-level KonfAI workflows are exposed through three entrypoint functions and three root classes:
- training
- prediction
- evaluation
These are the main Python APIs behind the konfai CLI.
Training
- konfai.trainer.train(command=State.TRAIN, overwrite=False, model=None, gpu=[], cpu=None, quiet=False, tensorboard=False, config=PosixPath('Config.yml'), checkpoints_dir=PosixPath('Checkpoints'), statistics_dir=PosixPath('Statistics'))
Build and execute the configured training workflow.
This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_train().
- Return type: DistributedObject
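A minimal usage sketch, assuming KonfAI is installed and that a Config.yml sits in the chosen working directory. The keyword names come from the signature above; build_train_kwargs and run_training are hypothetical helpers introduced only for illustration, and whether relative or absolute paths are preferred is an assumption.

```python
from pathlib import Path


def build_train_kwargs(workdir: Path, tensorboard: bool = False) -> dict:
    """Collect keyword arguments for konfai.trainer.train (hypothetical helper)."""
    return {
        "overwrite": False,
        "tensorboard": tensorboard,
        "config": workdir / "Config.yml",
        "checkpoints_dir": workdir / "Checkpoints",
        "statistics_dir": workdir / "Statistics",
    }


def run_training(workdir: Path) -> object:
    """Launch training; the import is lazy so the sketch loads without KonfAI."""
    from konfai.trainer import train  # requires KonfAI to be installed

    # Arguments not listed here keep the defaults shown in the signature.
    return train(**build_train_kwargs(workdir))
```

Keeping the argument collection separate from the call makes it easy to log or inspect the paths before any distributed process is started.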
- class konfai.trainer.Trainer(model=<konfai.network.network.ModelLoader object>, dataset={'dataset_filenames': ['default|./Dataset:mha'], 'groups_src': {'default|Labels': {'default|Labels': {'transforms': [], 'patch_transforms': []}}}, 'patch': <konfai.data.patching.DatasetPatch object>, 'use_cache': True, 'subset': <konfai.data.data_manager.TrainSubset object>, 'batch_size': 1, 'validation': 0.2, 'inline_augmentations': False, 'data_augmentations_list': {'DataAugmentation_0': <konfai.data.augmentation.DataAugmentationsList object>}}, train_name='default|TRAIN_01', manual_seed=None, epochs=100, it_validation=None, it_lr_update=None, autocast=False, gradient_checkpoints=None, gpu_checkpoints=None, ema_decay=0, data_log=None, early_stopping=None, save_checkpoint_mode='BEST')
Bases: DistributedObject
Public API for training a model using the KonfAI framework. Wraps setup, checkpointing, resuming, logging, and launching the distributed _Trainer.
Main responsibilities:
  - Initialization from config (via @config)
  - Model and EMA setup
  - Checkpoint loading and saving
  - Distributed setup and launch
- Parameters:
  - model (ModelLoader) – Loader for model architecture.
  - dataset (DataTrain) – Training/validation dataset.
  - train_name (str) – Training session name.
  - epochs (int) – Number of epochs to run.
  - autocast (bool) – Enable AMP training.
  - gradient_checkpoints (list[str] | None) – Modules to use gradient checkpointing on.
  - gpu_checkpoints (list[str] | None) – Modules to pin on specific GPUs.
  - ema_decay (float) – EMA decay factor.
  - early_stopping (EarlyStopping | None) – Optional early stopping config.
  - save_checkpoint_mode (str) – Either “BEST” or “ALL”.
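To make the ema_decay parameter concrete, here is a sketch of the standard exponential-moving-average update on plain floats. It illustrates the general technique only; how KonfAI applies the factor internally, and whether ema_decay=0 disables EMA altogether, are not stated in this reference and are left as assumptions. The function name ema_update is hypothetical.

```python
def ema_update(shadow: list[float], params: list[float], decay: float) -> list[float]:
    """Return updated EMA ("shadow") weights after one optimizer step.

    Standard formulation: shadow = decay * shadow + (1 - decay) * params.
    """
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in [0, 1]")
    if len(shadow) != len(params):
        raise ValueError("shadow and params must have the same length")
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

With this formulation a decay of 0 makes the shadow weights track the current weights exactly, while values close to 1 average over many steps.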
- setup(world_size)
Initializes the training environment:
  - Clears previous outputs (unless resuming)
  - Initializes model and EMA
  - Loads checkpoint (if resuming)
  - Prepares dataloaders
- Parameters:
  - world_size (int) – Total number of distributed processes.
- run_process(world_size, global_rank, local_rank, dataloaders)
Launches the actual training process via the internal _Trainer class. Wraps the model with DDP or a CPU fallback, attaches EMA, and starts training.
- Parameters:
  - world_size (int) – Total number of distributed processes.
  - global_rank (int) – Global rank of the current process.
  - local_rank (int) – Local rank within the node.
  - dataloaders (list[DataLoader]) – Training and validation dataloaders.
- class konfai.trainer.EarlyStopping(monitor=None, patience=10, min_delta=0.0, mode='min')
Bases: EarlyStoppingBase
Implements early stopping logic with configurable patience and monitored metrics.
- patience
Number of checks with no improvement before stopping.
- Type: int
- min_delta
Minimum change to qualify as improvement.
- Type: float
- mode
“min” or “max” depending on optimization direction.
- Type: str
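The interaction of patience, min_delta, and mode can be sketched as follows. This is a generic patience-based early-stopping loop written for illustration, not KonfAI's EarlyStopping implementation; the class name EarlyStoppingSketch and its step method are hypothetical, and the monitor argument is left out.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Generic patience-based early stopping (illustrative only)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best: Optional[float] = None
        self.bad_checks = 0  # consecutive checks without improvement

    def step(self, value: float) -> bool:
        """Record one validation check; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            # Lower is better: the value must beat the best by more than min_delta.
            improved = value < self.best - self.min_delta
        else:
            # Higher is better.
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

A change smaller than min_delta counts as no improvement, so a slowly drifting loss still exhausts the patience budget.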
Prediction
- konfai.predictor.predict(models, overwrite=False, gpu=[], cpu=1, quiet=False, tb=False, prediction_file=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Prediction.yml'), predictions_dir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Predictions'))
Build and execute the configured prediction workflow.
This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_predict().
- Return type: DistributedObject
- class konfai.predictor.Predictor(model=<konfai.network.network.ModelLoader object>, dataset={'dataset_filenames': ['default|./Dataset'], 'groups_src': {'default': {'default|Labels': {'transforms': [], 'patch_transforms': []}}}, 'patch': <konfai.data.patching.DatasetPatch object>, 'use_cache': False, 'subset': <konfai.data.data_manager.PredictionSubset object>, 'batch_size': 1, 'validation': None, 'inline_augmentations': False, 'data_augmentations_list': {'DataAugmentation_0': <konfai.data.augmentation.DataAugmentationsList object>}}, combine='Mean', train_name='name', manual_seed=None, gpu_checkpoints=None, autocast=False, outputs_dataset={'default|Default': <konfai.predictor.OutputDatasetLoader object>}, data_log=None)
Bases: DistributedObject
KonfAI’s main prediction controller.
This class orchestrates the prediction phase by:
  - Loading model weights from checkpoint(s) or URL(s)
  - Preparing datasets and output configurations
  - Managing distributed inference with optional multi-GPU support
  - Applying transformations and saving predictions
  - Optionally logging results to TensorBoard
- model
The neural network model to use for prediction.
- Type:
- dataset
Dataset manager for prediction data.
- Type:
- combine_classpath
Path to the reduction strategy (e.g., “Mean”).
- Type: str
- autocast
Whether to enable AMP inference.
- Type: bool
- outputs_dataset
Mapping from layer names to output writers.
- Type:
- setup(world_size)
Set up the predictor for inference.
This method performs all necessary initialization steps before running predictions:
  - Ensures output directories exist, and optionally prompts the user before overwriting existing predictions.
  - Copies the current configuration file (Prediction.yml) into the output directory for reproducibility.
  - Dynamically loads pretrained weights from local files or remote URLs.
  - Wraps the base model into a ModelComposite to support ensemble inference.
  - Initializes the prediction dataloader, with proper distribution across available GPUs.
- Parameters:
  - world_size (int) – Total number of processes or GPUs used for distributed prediction.
- run_process(world_size, global_rank, local_rank, dataloaders)
Launch prediction on the given process rank.
- Parameters:
  - world_size (int) – Total number of processes.
  - global_rank (int) – Rank of the current process.
  - local_rank (int) – Local device rank.
  - dataloaders (list[DataLoader]) – List of data loaders for prediction.
- class konfai.predictor.OutputDataset(filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patch_combine, reduction)
Bases: Dataset, NeedDevice, ABC
Abstract prediction sink that accumulates model outputs and writes them to disk.
Concrete subclasses define how layers are accumulated across patches, augmentations, and multiple models before the final prediction volume is materialized.
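As an illustration of accumulating outputs from several models before a "Mean" reduction, the sketch below keeps a running element-wise mean over flat lists of floats. It is not the OutputDataset implementation: the class MeanAccumulator is hypothetical, and real predictions would be tensors combined per patch and per augmentation as well.

```python
class MeanAccumulator:
    """Running element-wise mean over predictions added one at a time."""

    def __init__(self) -> None:
        self.count = 0
        self.mean: list[float] = []

    def add(self, prediction: list[float]) -> None:
        """Fold one prediction into the running mean without storing all of them."""
        if self.count > 0 and len(prediction) != len(self.mean):
            raise ValueError("all predictions must have the same length")
        self.count += 1
        if self.count == 1:
            self.mean = list(prediction)
        else:
            # Incremental mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n
            self.mean = [m + (p - m) / self.count for m, p in zip(self.mean, prediction)]
```

The incremental form avoids holding every model's output in memory at once, which matters when each output is a full volume.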
Evaluation
- konfai.evaluator.evaluate(overwrite=False, gpu=[], cpu=1, quiet=False, tb=False, evaluations_file=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Evaluation.yml'), evaluations_dir=PosixPath('/home/docs/checkouts/readthedocs.org/user_builds/konfai/checkouts/latest/docs/source/Evaluations'))
Build and execute the configured evaluation workflow.
This compatibility wrapper preserves the historical CLI-facing API while delegating the pure build step to build_evaluate().
- Return type: DistributedObject
- class konfai.evaluator.Evaluator(train_name='default|TRAIN_01', metrics={'default': <konfai.evaluator.TargetCriterionsLoader object>}, dataset={'dataset_filenames': ['default|./Dataset:mha'], 'groups_src': {'default': {'default|group_dest': {'transforms': [], 'patch_transforms': []}}}, 'patch': None, 'use_cache': True, 'subset': <konfai.data.data_manager.PredictionSubset object>, 'batch_size': 1, 'validation': None, 'inline_augmentations': False, 'data_augmentations_list': {}})
Bases: DistributedObject
Distributed evaluation engine for computing metrics on model predictions.
This class handles the evaluation of predicted outputs using predefined metric loaders. It supports multi-output and multi-target configurations, computes aggregated statistics across training and validation datasets, and synchronizes results across processes.
Evaluation results are stored in JSON format and optionally displayed during iteration.
- Parameters:
  - train_name (str) – Unique name of the evaluation run, used for logging and output folders.
  - metrics (dict[str, TargetCriterionsLoader]) – Dictionary mapping output groups to loaders of target metrics.
  - dataset (DataMetric) – Dataset provider configured for evaluation mode.
- statistics_train
Object used to store training evaluation metrics.
- Type:
- statistics_validation
Object used to store validation evaluation metrics.
- Type:
- dataloader
DataLoaders for training and validation sets.
- Type: list[DataLoader]
- metric_path
Path to the evaluation output directory.
- Type:
- metrics
Instantiated metrics organized by output and target groups.
- Type:
- setup(world_size)
Prepare the evaluator for distributed metric computation.
This method performs the following steps:
  - Checks whether previous evaluation results exist and optionally overwrites them.
  - Creates the output directory and copies the current configuration file for reproducibility.
  - Loads the evaluation dataset according to the world size.
- Parameters:
  - world_size (int) – Number of processes in the distributed evaluation setup.
- update(batch_sample, statistics)
Compute metrics for a batch and update running statistics.
- Parameters:
  - batch_sample (dict[str, BatchDataItem]) – The batch sample object containing tensors and their metadata.
  - statistics (Statistics) – The statistics object to update (train or validation).
- Returns: Dictionary of computed metric values with keys in the format 'output_group:target_group:MetricName'.
- Return type:
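The key format returned by update can be illustrated with a small sketch that flattens a nested mapping into 'output_group:target_group:MetricName' keys and splits them again. Both helper functions and the group and metric names used with them are hypothetical; only the key layout comes from the description above.

```python
def flatten_metric_keys(results: dict[str, dict[str, dict[str, float]]]) -> dict[str, float]:
    """Flatten {output_group: {target_group: {metric: value}}} into colon-joined keys."""
    flat: dict[str, float] = {}
    for output_group, targets in results.items():
        for target_group, metrics in targets.items():
            for metric_name, value in metrics.items():
                flat[f"{output_group}:{target_group}:{metric_name}"] = value
    return flat


def split_metric_key(key: str) -> tuple[str, str, str]:
    """Split a flattened key back into (output_group, target_group, metric_name)."""
    output_group, target_group, metric_name = key.split(":", 2)
    return output_group, target_group, metric_name
```

Splitting with a maximum of two separators keeps any further colons inside the metric name intact.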
- run_process(world_size, global_rank, gpu, dataloaders)
Execute the distributed evaluation loop over the training and validation datasets.
This method iterates through the provided DataLoaders (train and optionally validation), updates the metric statistics using the configured metrics dictionary, and synchronizes the results across all processes. On the global rank 0, the metrics are saved as JSON files.
Metrics are displayed in real-time using tqdm progress bars, showing a summary of the current batch’s computed values.
- Parameters:
  - world_size (int) – Total number of distributed processes.
  - global_rank (int) – Global rank of the current process (used for writing results).
  - gpu (int) – Local GPU ID used for synchronization.
  - dataloaders (list[DataLoader]) – A list containing one or two DataLoaders:
    - dataloaders[0] is used for training evaluation.
    - dataloaders[1] (optional) is used for validation evaluation.
Notes
Only the main process (global_rank == 0) writes final results to disk.
- class konfai.evaluator.Statistics(filename)
Bases: object
Utility class to accumulate, structure, and write evaluation metric results.
This class is used to:
  - Collect metrics for each dataset sample.
  - Compute aggregate statistics (mean, std, percentiles, etc.).
  - Export all results in a structured JSON format, including both per-case and aggregate values.
- Parameters:
  - filename (Path) – Path to the output JSON file that will store the final results.
- add(values, name_dataset)
Add a set of metric values for a given dataset case.
- static get_statistic(values)
Compute statistical aggregates for a list of metric values.
- write(outputs)
Write the collected and aggregated statistics to the configured output file.
The output JSON structure contains:
  - case: All individual metrics per sample.
  - aggregates: Global statistics computed over all cases.
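The two-part JSON layout described above can be sketched with the standard library. The top-level keys case and aggregates come from the description; the aggregate field names (mean, std, min, max, median), the use of the population standard deviation, and the helper names aggregate and build_report are assumptions made for this illustration.

```python
import json
import statistics


def aggregate(values: list[float]) -> dict[str, float]:
    """Compute a few summary statistics for one metric (field names are assumed)."""
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population standard deviation
        "min": min(values),
        "max": max(values),
        "median": statistics.median(values),
    }


def build_report(cases: dict[str, dict[str, float]]) -> dict:
    """Build a {"case": ..., "aggregates": ...} report from per-case metric values."""
    per_metric: dict[str, list[float]] = {}
    for metrics in cases.values():
        for name, value in metrics.items():
            per_metric.setdefault(name, []).append(value)
    return {
        "case": cases,
        "aggregates": {name: aggregate(vals) for name, vals in per_metric.items()},
    }


def to_json(report: dict) -> str:
    """Serialize the report the way it might be written to the output file."""
    return json.dumps(report, indent=2)
```

Keeping the per-case values next to the aggregates lets a reader trace an unusual mean or standard deviation back to the individual samples that caused it.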