# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Training workflow entrypoints and orchestration for KonfAI."""
import os
import shutil
from pathlib import Path
import torch
import torch.distributed as dist
import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817
from torch.optim.swa_utils import AveragedModel
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
from konfai import (
checkpoints_directory,
config_file,
cuda_visible_devices,
current_date,
konfai_state,
statistics_directory,
)
from konfai.data.data_manager import BatchSample, DataTrain
from konfai.network.network import Model, ModelLoader, NetState, Network
from konfai.utils.config import apply_config, config
from konfai.utils.errors import ConfigError, TrainerError
from konfai.utils.runtime import (
DataLog,
DistributedObject,
State,
configure_workflow_environment,
confirm_overwrite_or_raise,
description,
run_distributed_app,
)
class EarlyStoppingBase:
    """
    Minimal protocol for early stopping strategies used by :class:`Trainer`.

    This default implementation never stops training. Its score is simply the sum of
    every logged loss/metric value.
    """

    def __init__(self):
        pass

    def is_stopped(self) -> bool:
        """Return whether training should stop; always ``False`` for the base strategy."""
        return False

    def get_score(self, values: dict[str, float]):
        """
        Reduce the logged losses/metrics to a single score.

        Args:
            values (dict[str, float]): Logged values keyed by measure name.

        Returns:
            The sum of all values (``0`` for an empty dict).
        """
        return sum(values.values())

    def __call__(self, current_score: float) -> bool:
        """Record a new score and return whether to stop; always ``False`` here."""
        return False
@config()
class EarlyStopping(EarlyStoppingBase):
    """
    Early stopping with configurable patience and monitored metrics.

    Each call with a new score compares it with the best score seen so far. Once the
    score has failed to improve by more than ``min_delta`` for ``patience`` consecutive
    calls, training is flagged as stopped (and stays stopped).

    Attributes:
        monitor (list[str]): Metrics to monitor; when empty, all logged values are summed.
        patience (int): Number of checks with no improvement before stopping.
        min_delta (float): Minimum change to qualify as improvement.
        mode (str): "min" or "max" depending on optimization direction.
    """

    def __init__(
        self,
        monitor: list[str] | None = None,
        patience: int = 10,
        min_delta: float = 0.0,
        mode: str = "min",
    ):
        super().__init__()
        self.monitor = [] if monitor is None else monitor
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Number of consecutive checks without sufficient improvement.
        self.counter = 0
        self.best_score: float | None = None
        self.early_stop = False

    def is_stopped(self) -> bool:
        """Return ``True`` once patience has been exhausted."""
        return self.early_stop

    def get_score(self, values: dict[str, float]):
        """
        Reduce logged values to the score tracked by early stopping.

        Args:
            values (dict[str, float]): Logged losses/metrics keyed by name.

        Returns:
            The sum of the monitored values, or of all values when ``monitor`` is empty.

        Raises:
            TrainerError: If a monitored metric is absent from ``values``.
        """
        if not self.monitor:
            return super().get_score(values)
        missing = [name for name in self.monitor if name not in values]
        if missing:
            raise TrainerError(
                f"Metric(s) {missing} specified in EarlyStopping.monitor not found in logged values.",
                f"Available keys: {list(values)}. Please check your configuration.",
            )
        return sum(value for name, value in values.items() if name in self.monitor)

    def __call__(self, current_score: float) -> bool:
        """
        Register ``current_score`` and return whether training should stop.

        Raises:
            TrainerError: If ``mode`` is neither "min" nor "max".
        """
        # Validate on every call (previously an invalid mode was accepted on the first call).
        if self.mode not in ("min", "max"):
            raise TrainerError("Mode must be 'min' or 'max'.")
        if self.best_score is None:
            self.best_score = current_score
            return False
        if self.mode == "min":
            improvement = self.best_score - current_score
        else:
            improvement = current_score - self.best_score
        if improvement > self.min_delta:
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
class _Trainer:
    """
    Internal class for managing the training loop in a distributed or standalone setting.

    Handles:
    - Epoch iteration with training and optional validation
    - Mixed precision support (autocast)
    - Exponential Moving Average (EMA) model tracking
    - Early stopping
    - Logging to TensorBoard (rank 0 only)
    - Model checkpoint saving and selection (ALL or BEST)

    This class is intended to be used via a context manager
    (`with _Trainer(...) as trainer:`) inside the public `Trainer` class.
    """

    def __init__(
        self,
        world_size: int,
        global_rank: int,
        local_rank: int,
        size: int,
        train_name: str,
        early_stopping: EarlyStopping | None,
        data_log: list[str] | None,
        save_checkpoint_mode: str,
        epochs: int,
        epoch: int,
        autocast: bool,
        it_validation: int | None,
        it_lr_update: int | None,
        it: int,
        model: Model,
        model_ema: AveragedModel | None,
        dataloader_training: DataLoader,
        dataloader_validation: DataLoader | None = None,
    ) -> None:
        """
        Store the training configuration and prepare logging/checkpoint state.

        ``it_validation`` and ``it_lr_update`` default to one epoch (``len(dataloader_training)``).
        ``data_log`` entries have the form ``"<name>/<DataLog member>/<max items>"``.
        """
        self.world_size = world_size
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.size = size
        self.save_checkpoint_mode = save_checkpoint_mode
        self.train_name = train_name
        self.epochs = epochs
        self.epoch = epoch
        self.model = model
        self.dataloader_training = dataloader_training
        self.dataloader_validation = dataloader_validation
        self.autocast = autocast
        self.model_ema = model_ema
        self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
        self.it_validation = len(dataloader_training) if it_validation is None else it_validation
        self.it_lr_update = len(dataloader_training) if it_lr_update is None else it_lr_update
        self.it = it
        # Every TensorBoard write in this class is guarded by `global_rank == 0`, so only
        # rank 0 opens a writer; other ranks keep None and create no stray event files.
        self.tb: SummaryWriter | None = (
            SummaryWriter(log_dir=statistics_directory() / self.train_name / "tb") if self.global_rank == 0 else None
        )
        self._best_checkpoint_path: Path | None = None
        self._best_checkpoint_loss: float | None = None
        if self.global_rank == 0 and self.save_checkpoint_mode == "BEST":
            self._initialize_best_checkpoint_state()
        # ":" in the logged name is replaced by "." to match the keys used when logging.
        self.data_log: dict[str, tuple[DataLog, int]] = {}
        for data in data_log or []:
            parts = data.split("/")
            self.data_log[parts[0].replace(":", ".")] = (DataLog[parts[1]], int(parts[2]))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        """Close the TensorBoard writer (if any) and write a final checkpoint (rank 0)."""
        if self.tb is not None:
            self.tb.close()
        self.checkpoint_save(None)

    def _initialize_best_checkpoint_state(self) -> None:
        """
        Bootstrap BEST-checkpoint tracking from checkpoints already on disk (resume case).

        Every ``*.pt`` file of this run is loaded to read its stored ``"loss"`` (missing
        means +inf). The lowest one becomes the tracked best, and all others are deleted.
        """
        path = checkpoints_directory() / self.train_name
        if not path.exists():
            return
        all_checkpoints = sorted(path.glob("*.pt"))
        best_loss = float("inf")
        best_ckpt: Path | None = None
        for checkpoint_path in all_checkpoints:
            state_dict = torch.load(
                checkpoint_path,
                map_location=torch.device("cpu"),
                weights_only=False,
            )  # nosec B614
            checkpoint_loss = float(state_dict.get("loss", float("inf")))
            if checkpoint_loss < best_loss:
                best_loss = checkpoint_loss
                best_ckpt = checkpoint_path
        if best_ckpt is None:
            # No checkpoint reported a loss below +inf (or all were NaN): keep them all
            # instead of deleting every saved state.
            return
        self._best_checkpoint_path = best_ckpt
        self._best_checkpoint_loss = best_loss
        for checkpoint_path in all_checkpoints:
            if checkpoint_path != best_ckpt:
                checkpoint_path.unlink()

    def _update_best_checkpoint(self, checkpoint_path: Path, loss: float) -> None:
        """
        Keep only the current best checkpoint (lowest loss) without rescanning all saves.

        A strictly lower ``loss`` promotes ``checkpoint_path`` and deletes the previous best.
        Otherwise the freshly written ``checkpoint_path`` is deleted.
        """
        is_new_best = self._best_checkpoint_loss is None or loss < self._best_checkpoint_loss
        if is_new_best:
            previous_best = self._best_checkpoint_path
            self._best_checkpoint_loss = loss
            self._best_checkpoint_path = checkpoint_path
            if previous_best is not None and previous_best != checkpoint_path and previous_best.exists():
                previous_best.unlink()
            return
        checkpoint_path.unlink()

    def run(self) -> None:
        """
        Launch the training loop, one epoch at a time.

        For any state other than a fresh TRAIN (e.g. resume), a validation pass runs first.
        The loop stops as soon as ``early_stopping.is_stopped()`` is true. Training
        augmentations are reset between epochs.
        """
        self.dataloader_training.dataset.load("Train")
        if self.dataloader_validation is not None:
            self.dataloader_validation.dataset.load("Validation")
            if State[konfai_state()] != State.TRAIN:
                self._validate()
        with tqdm.tqdm(
            iterable=range(self.epoch, self.epochs),
            leave=False,
            total=self.epochs,
            initial=self.epoch,
            desc="Progress",
        ) as epoch_tqdm:
            for self.epoch in epoch_tqdm:
                self.train()
                if self.early_stopping.is_stopped():
                    break
                self.dataloader_training.dataset.reset_augmentation("Train")

    def train(self) -> None:
        """
        Perform one training epoch.

        Covers mixed precision, DDP/CPU training, EMA updates, learning-rate updates every
        ``it_lr_update`` iterations, and logging/validation/checkpointing/early stopping
        every ``it_validation`` iterations.
        """
        self.model.train()
        self.model.module.set_state(NetState.TRAIN)
        if self.model_ema is not None:
            self.model_ema.eval()
            self.model_ema.module.set_state(NetState.TRAIN)
        with tqdm.tqdm(
            iterable=enumerate(self.dataloader_training),
            desc=f"Training : {description(self.model, self.model_ema)}",
            total=len(self.dataloader_training),
            leave=False,
            ncols=0,
        ) as batch_iter:
            for _, batch_sample in batch_iter:
                # NOTE(review): backward() is kept under autocast because it may compute
                # losses from fp16 outputs; PyTorch AMP recommends running backward outside
                # autocast, so confirm where Network.backward computes its loss.
                with torch.amp.autocast("cuda", enabled=self.autocast):
                    self.model(batch_sample)
                    self.model.module.backward(self.model)
                if self.model_ema is not None:
                    self.model_ema.update_parameters(self.model)
                self.it += 1
                if self.it % self.it_lr_update == 0:
                    self.model.module.update_lr()
                if self.it % self.it_validation == 0:
                    loss = self._train_log(batch_sample)
                    if self.dataloader_validation is not None:
                        loss = self._validate()
                    # Only rank 0 receives aggregated measures; other ranks get None, and
                    # scoring None would raise (checkpoint_save is a no-op there anyway).
                    if loss is not None:
                        score = self.early_stopping.get_score(loss)
                        self.checkpoint_save(score)
                        if self.early_stopping(score):
                            # NOTE(review): the stop decision is taken on rank 0 only and is
                            # not broadcast; in multi-process runs other ranks keep training.
                            break
                batch_iter.set_description(f"Training : {description(self.model, self.model_ema)}")

    @torch.no_grad()
    def _validate(self) -> dict[str, float] | None:
        """
        Run one validation pass and log its losses/metrics.

        The model and, if present, the EMA model are evaluated on every validation batch.
        Afterwards, validation augmentations are reset, ranks are synchronised when
        distributed, and the models are switched back to the training state.

        Returns:
            dict[str, float] | None: Aggregated validation values on rank 0; ``None`` on
            other ranks or when there is no validation dataloader.
        """
        if self.dataloader_validation is None:
            return None
        self.model.eval()
        self.model.module.set_state(NetState.PREDICTION)
        if self.model_ema is not None:
            self.model_ema.module.set_state(NetState.PREDICTION)
        batch_sample: BatchSample = {}
        with tqdm.tqdm(
            iterable=enumerate(self.dataloader_validation),
            desc=f"Validation : {description(self.model, self.model_ema)}",
            total=len(self.dataloader_validation),
            leave=False,
            ncols=0,
        ) as batch_iter:
            for _, batch_sample in batch_iter:
                self.model(batch_sample)
                if self.model_ema is not None:
                    self.model_ema.module(batch_sample)
                batch_iter.set_description(f"Validation : {description(self.model, self.model_ema)}")
        self.dataloader_validation.dataset.reset_augmentation("Validation")
        if dist.is_initialized():
            dist.barrier()
        self.model.train()
        self.model.module.set_state(NetState.TRAIN)
        if self.model_ema is not None:
            self.model_ema.module.set_state(NetState.TRAIN)
        return self._validation_log(batch_sample)

    def checkpoint_save(self, loss: float | None) -> None:
        """
        Save model, EMA and optimizer states (rank 0 only).

        In BEST mode only the lowest-loss checkpoint is kept; in ALL mode every save is kept.

        Args:
            loss (float | None): Score used for best-checkpoint selection; ``None`` saves
                without taking part in the selection.
        """
        if self.global_rank != 0:
            return
        path = checkpoints_directory() / self.train_name
        path.mkdir(parents=True, exist_ok=True)
        save_path = path / f"{current_date()}.pt"
        save_dict = {
            "epoch": self.epoch,
            "it": self.it,
            # NOTE(review): the final save from __exit__ passes None, stored as 0. On a later
            # BEST-mode resume, _initialize_best_checkpoint_state would rank it best whenever
            # real losses are positive. Confirm this is intended.
            "loss": loss if loss is not None else 0,
            "Model": self.model.module.state_dict(),
        }
        if self.model_ema is not None:
            save_dict["Model_EMA"] = self.model_ema.module.state_dict()
        for network_name, network in self.model.module.get_networks().items():
            if network.optimizer is None:
                continue
            save_dict[f"{network_name}_optimizer_state_dict"] = network.optimizer.state_dict()
            save_dict[f"{network_name}_it"] = network._it
            save_dict[f"{network_name}_nb_lr_update"] = network._nb_lr_update
        torch.save(save_dict, save_path)
        if self.save_checkpoint_mode == "BEST" and loss is not None:
            self._update_best_checkpoint(save_path, loss)

    @torch.no_grad()
    def _log(
        self,
        type_log: str,
        batch_sample: BatchSample,
    ) -> dict[str, float] | None:
        """
        Log losses, metrics, learning rates and optionally images to TensorBoard.

        Args:
            type_log (str): "Training" or "Validation".
            batch_sample (BatchSample): Items of the current (last) batch.

        Returns:
            dict[str, float] | None: Aggregated losses and metrics on rank 0, else ``None``.
        """
        models: dict[str, Network] = {"": self.model.module}
        if self.model_ema is not None:
            models["_EMA"] = self.model_ema.module
        # Called on every rank; only rank 0 uses the result below.
        measures = DistributedObject.get_measure(
            self.world_size,
            self.global_rank,
            self.local_rank * self.size + self.size - 1,
            models,
            (
                self.it_validation
                if type_log == "Training" or self.dataloader_validation is None
                else len(self.dataloader_validation)
            ),
        )
        if self.global_rank != 0:
            return None

        # Log requested batch items directly; names not in the batch are treated as layer
        # names and extracted from each model below.
        images_log = []
        for name, (log_fn, max_items) in self.data_log.items():
            if name in batch_sample:
                log_fn(
                    self.tb,
                    f"{type_log}/{name}",
                    batch_sample[name].tensor[:max_items].detach().cpu().numpy(),
                    self.it,
                )
            else:
                images_log.append(name.replace(":", "."))

        for label, model in models.items():
            for name, network in model.get_networks().items():
                if network.measure is None:
                    continue
                entry = measures[f"{name}{label}"]
                # entry[0] holds losses, entry[1] metrics; each value is (weight, value).
                for kind, values in (("Loss", entry[0]), ("Metric", entry[1])):
                    self.tb.add_scalars(
                        f"{type_log}/{name}/{kind}/{label}",
                        {k.replace(":", "."): v[1] for k, v in values.items()},
                        self.it,
                    )
                    self.tb.add_scalars(
                        f"{type_log}/{name}/{kind}_weight/{label}",
                        {k.replace(":", "."): v[0] for k, v in values.items()},
                        self.it,
                    )
            if images_log:
                inputs = [v.tensor for v in batch_sample.values() if v.is_input]
                for name, layer, _ in model.get_layers(inputs, images_log):
                    log_fn, max_items = self.data_log[name]
                    log_fn(
                        self.tb,
                        f"{type_log}/{name}{label}",
                        layer[:max_items].detach().cpu().numpy(),
                        self.it,
                    )

        if type_log == "Training":
            for name, network in self.model.module.get_networks().items():
                if network.optimizer is not None:
                    self.tb.add_scalar(
                        f"{type_log}/{name}/Learning Rate",
                        network.optimizer.param_groups[0]["lr"],
                        self.it,
                    )

        # The returned values (used for early stopping / BEST checkpoints) come from the EMA
        # model when one exists, otherwise from the main model.
        # NOTE(review): during training the EMA model is never run forward; confirm its
        # measures are meaningful when there is no validation set.
        score_label = "_EMA" if self.model_ema is not None else ""
        loss: dict[str, float] = {}
        for name, network in self.model.module.get_networks().items():
            if network.measure is not None:
                entry = measures[f"{name}{score_label}"]
                loss.update({k: v[1] for k, v in entry[0].items()})
                loss.update({k: v[1] for k, v in entry[1].items()})
        return loss

    @torch.no_grad()
    def _train_log(self, batch_sample: BatchSample) -> dict[str, float] | None:
        """Wrapper for _log during training."""
        return self._log("Training", batch_sample)

    @torch.no_grad()
    def _validation_log(self, batch_sample: BatchSample) -> dict[str, float] | None:
        """Wrapper for _log during validation."""
        return self._log("Validation", batch_sample)
@config()
class Trainer(DistributedObject):
    """
    Public API for training a model using the KonfAI framework.

    Wraps setup, checkpointing, resuming, logging, and launching distributed _Trainer.

    Main responsibilities:
    - Initialization from config (via @config)
    - Model and EMA setup
    - Checkpoint loading and saving
    - Distributed setup and launch

    Args:
        model (ModelLoader): Loader for model architecture.
        dataset (DataTrain): Training/validation dataset.
        train_name (str): Training session name.
        manual_seed (int | None): Random seed.
        epochs (int): Number of epochs to run.
        it_validation (int | None): Validation interval.
        it_lr_update (int | None): Learning rate update interval.
        autocast (bool): Enable AMP training.
        gradient_checkpoints (list[str] | None): Modules to use gradient checkpointing on.
        gpu_checkpoints (list[str] | None): Modules to pin on specific GPUs.
        ema_decay (float): EMA decay factor (0 disables EMA).
        data_log (list[str] | None): Logging instructions.
        early_stopping (EarlyStopping | None): Optional early stopping config.
        save_checkpoint_mode (str): Either "BEST" or "ALL".
    """

    def __init__(
        self,
        # Instance defaults are kept as-is: the @config machinery reads them.
        model: ModelLoader = ModelLoader(),
        dataset: DataTrain = DataTrain(),
        train_name: str = "default|TRAIN_01",
        manual_seed: int | None = None,
        epochs: int = 100,
        it_validation: int | None = None,
        it_lr_update: int | None = None,
        autocast: bool = False,
        gradient_checkpoints: list[str] | None = None,
        gpu_checkpoints: list[str] | None = None,
        ema_decay: float = 0,
        data_log: list[str] | None = None,
        early_stopping: EarlyStopping | None = None,
        save_checkpoint_mode: str = "BEST",
    ) -> None:
        # .get(): an unset variable must yield the intended ConfigError, not a KeyError.
        if os.environ.get("KONFAI_CONFIG_MODE") != "Done":
            raise ConfigError("Trainer requires KONFAI_CONFIG_MODE='Done' before initialization.")
        super().__init__(train_name)
        self.manual_seed = manual_seed
        self.dataset = dataset
        self.autocast = autocast
        self.epochs = epochs
        self.epoch = 0
        self.early_stopping = early_stopping
        self.it = 0
        self.it_validation = it_validation
        self.it_lr_update = it_lr_update
        self.model = model.get_model(train=True)
        self.ema_decay = ema_decay
        self.model_ema: AveragedModel | None = None
        self.data_log = data_log
        self.gradient_checkpoints = gradient_checkpoints
        self.gpu_checkpoints = gpu_checkpoints
        self.save_checkpoint_mode = save_checkpoint_mode
        self.config_path_src = config_file()
        config_namefile = self.config_path_src.name.replace(".yml", "")
        self.config_namefile = statistics_directory() / self.name / f"{config_namefile}_{self.it}.yml"
        # Number of devices one model replica spans (model parallelism via gpu_checkpoints).
        self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
        state = State[konfai_state()]
        self.dataset.prepare()
        self.model.init(self.autocast, state, self.dataset.get_groups_dest())
        self.model.init_outputs_group()
        self.model._compute_channels_trace(
            self.model,
            self.model.in_channels,
            self.gradient_checkpoints,
            self.gpu_checkpoints,
        )

    def setup(self, world_size: int):
        """
        Initialize the training environment.

        - Clears previous outputs (unless resuming), after confirmation
        - Loads the checkpoint (any state other than TRAIN) and initializes the model and EMA
        - Copies the config file and prepares dataloaders, recording train/validation names

        Args:
            world_size (int): Total number of distributed processes.
        """
        state = State[konfai_state()]
        if state != State.RESUME and (checkpoints_directory() / self.name).exists():
            confirm_overwrite_or_raise(checkpoints_directory() / self.name, "model", TrainerError)
            for directory_path in (statistics_directory(), checkpoints_directory()):
                path = directory_path / self.name
                if path.is_dir():
                    shutil.rmtree(path)
                elif path.exists():
                    path.unlink()
        state_dict: dict = {}
        if state != State.TRAIN:
            state_dict = self._load()
        self.model.load(state_dict, init=True, ema=False)
        if self.ema_decay > 0:
            self.model_ema = AveragedModel(self.model, avg_fn=self._avg_fn)
            # state_dict is {} (never None) on a fresh run: only restore EMA weights
            # when a checkpoint was actually loaded.
            if state_dict:
                self.model_ema.module.load(state_dict, init=False, ema=True)
        statistics_path = statistics_directory() / self.name
        statistics_path.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(self.config_path_src, self.config_namefile)
        self.dataloader, train_names, validation_names = self.dataset.get_data(world_size // self.size)
        for prefix, names in (("Train", train_names), ("Validation", validation_names)):
            with open(statistics_path / f"{prefix}_{self.it}.txt", "w", encoding="utf-8") as f:
                f.writelines(f"{name}\n" for name in names)

    def set_model(self, path_to_model: Path) -> None:
        """Set the checkpoint (local path or ``"<https url>:<key>"``) loaded by setup()."""
        self.path_to_model = str(path_to_model)

    def __exit__(self, exc_type, value, traceback):
        """Exit training context and trigger save of model/checkpoints."""
        super().__exit__(exc_type, value, traceback)
        self._save()

    def _load(self) -> dict[str, dict[str, torch.Tensor]]:
        """
        Load a previously saved checkpoint from local disk or from an ``https://`` URL.

        URLs use the form ``"<url>:<key>"``; the downloaded state dict is returned as
        ``{key: state_dict}``. ``"epoch"`` / ``"it"`` entries, when present, restore
        ``self.epoch`` / ``self.it``.

        Returns:
            dict: State dictionary loaded from the checkpoint.

        Raises:
            TrainerError: If no checkpoint was set via set_model().
            ValueError: If the entry is neither a valid ``<url>:<key>`` nor an existing path.
            Exception: If downloading from the URL fails.
        """
        path_to_model = getattr(self, "path_to_model", None)
        if path_to_model is None:
            raise TrainerError(
                "A checkpoint is required for this command but none was provided.",
                "Pass a model path when running anything other than a fresh TRAIN.",
            )
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        if path_to_model.startswith("https://"):
            # Split on the LAST ':' so the "https:" scheme (and any port) stays in the URL.
            url, _, key = path_to_model.rpartition(":")
            if url == "https" or not key or "/" in key:
                raise ValueError(f"Invalid model path entry: {path_to_model}")
            try:
                state_dict = {
                    key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
                }
            except Exception as e:
                raise Exception(f"Model : {path_to_model} does not exist !") from e
        elif Path(path_to_model).exists():
            state_dict = torch.load(
                path_to_model, map_location=torch.device("cpu"), weights_only=False
            )  # nosec B614
        else:
            raise ValueError(f"Invalid model path entry: {path_to_model}")
        if "epoch" in state_dict:
            self.epoch = state_dict["epoch"]
        if "it" in state_dict:
            self.it = state_dict["it"]
        return state_dict

    def _save(self) -> None:
        """Rename the copied config file to append ``self.it``, if the copy exists."""
        if self.config_namefile.exists():
            new_name = f"{self.config_namefile.stem}_{self.it}.yml"
            self.config_namefile.rename(self.config_namefile.parent / new_name)

    def _avg_fn(self, averaged_model_parameter: torch.Tensor, model_parameter, num_averaged):
        """
        EMA update rule used by AveragedModel.

        Returns:
            torch.Tensor: ``(1 - ema_decay) * averaged + ema_decay * current``.
        """
        # NOTE(review): ema_decay weights the *current* parameter here (momentum-style),
        # the opposite of the common "decay=0.999" convention. Confirm configs match.
        return (1 - self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter

    def run_process(
        self,
        world_size: int,
        global_rank: int,
        local_rank: int,
        dataloaders: list[DataLoader],
    ):
        """
        Launch the actual training process via the internal `_Trainer` class.

        Wraps the model with DDP (or the CPU/single-process `Model` fallback), moves the
        EMA model alongside it, and starts training.

        Args:
            world_size (int): Total number of distributed processes.
            global_rank (int): Global rank of the current process.
            local_rank (int): Local rank within the node.
            dataloaders (list[DataLoader]): Training and validation dataloaders.
        """
        use_cuda = len(cuda_visible_devices()) > 0
        first_device = local_rank * self.size
        model = Network.to(self.model, first_device) if use_cuda else self.model
        if dist.is_initialized():
            ddp_kwargs: dict[str, object] = {"static_graph": True}
            # device_ids/output_device are only valid for single-device modules.
            if use_cuda and self.size == 1:
                ddp_kwargs.update({"device_ids": [local_rank], "output_device": local_rank})
            model = DDP(model, **ddp_kwargs)
        else:
            model = Model(model)
        # Move the EMA copy only when the main model was moved too (CPU runs stay on CPU).
        if self.model_ema is not None and use_cuda:
            self.model_ema.module = Network.to(self.model_ema.module, first_device)
        with _Trainer(
            world_size,
            global_rank,
            local_rank,
            self.size,
            self.name,
            self.early_stopping,
            self.data_log,
            self.save_checkpoint_mode,
            self.epochs,
            self.epoch,
            self.autocast,
            self.it_validation,
            self.it_lr_update,
            self.it,
            model,
            self.model_ema,
            *dataloaders,
        ) as t:
            t.run()
def build_train(
    command: State = State.TRAIN,
    model: Path | str | None = None,
    config: Path | str = Path("./Config.yml"),
    checkpoints_dir: Path | str = Path("./Checkpoints/"),
    statistics_dir: Path | str = Path("./Statistics/"),
) -> DistributedObject:
    """
    Build and return the configured training workflow without executing it.

    Parameters
    ----------
    command : State, optional
        Training command variant, typically ``State.TRAIN`` or ``State.RESUME``.
    model : Path | str | None, optional
        Checkpoint path used when resuming training.
    config : Path | str, optional
        Training configuration file.
    checkpoints_dir : Path | str, optional
        Output directory for checkpoints.
    statistics_dir : Path | str, optional
        Output directory for statistics and logs.

    Returns
    -------
    DistributedObject
        Configured trainer object ready to be executed by the runtime wrapper.
    """
    configure_workflow_environment(
        config_path=config,
        root="Trainer",
        state=command,
        path_env={
            "KONFAI_CHECKPOINTS_DIRECTORY": checkpoints_dir,
            "KONFAI_STATISTICS_DIRECTORY": statistics_dir,
        },
    )
    # Trainer.__init__ raises ConfigError unless this flag is "Done".
    os.environ["KONFAI_CONFIG_MODE"] = "Done"
    trainer = apply_config()(Trainer)()
    if model is not None:
        trainer.set_model(Path(model))
    return trainer
@run_distributed_app
def train(
    command: State = State.TRAIN,
    overwrite: bool = False,
    model: Path | str | None = None,
    # Evaluated once, at import time.
    gpu: list[int] | None = cuda_visible_devices(),
    cpu: int | None = None,
    quiet: bool = False,
    tensorboard: bool = False,
    config: Path | str = Path("./Config.yml"),
    checkpoints_dir: Path | str = Path("./Checkpoints/"),
    statistics_dir: Path | str = Path("./Statistics/"),
) -> DistributedObject:
    """
    Build and execute the configured training workflow.

    This compatibility wrapper preserves the historical CLI-facing API while
    delegating the pure build step to :func:`build_train`. The body only builds the
    trainer; execution is left to the ``run_distributed_app`` decorator.
    """
    # These runtime options are not needed by the build step; presumably the
    # decorator consumes them (TODO confirm).
    del overwrite, gpu, cpu, quiet, tensorboard
    return build_train(
        command=command,
        model=model,
        config=config,
        checkpoints_dir=checkpoints_dir,
        statistics_dir=statistics_dir,
    )
if __name__ == "__main__":
    # Positional arguments (command, overwrite, model) are kept exactly as before,
    # since run_distributed_app may inspect how the call was made.
    train(State.TRAIN, False, None)