# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Dataset assembly, subset selection, and dataloader orchestration for KonfAI."""
import math
import os
import random
import threading
import traceback
from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from typing import TypeAlias, cast
import numpy as np
import torch
import tqdm
from torch.cuda import device_count
from torch.utils import data
from torch.utils.data import DataLoader, Sampler
from konfai import konfai_root, konfai_state
from konfai.data.augmentation import DataAugmentationsList
from konfai.data.patching import DatasetManager, DatasetPatch
from konfai.data.transform import Transform, TransformLoader
from konfai.utils.config import config
from konfai.utils.dataset import Attribute, Dataset
from konfai.utils.errors import DatasetManagerError
from konfai.utils.runtime import State, get_cpu_info, get_memory, get_memory_info, memory_forecast
from konfai.utils.utils import SUPPORTED_EXTENSIONS, split_path_spec
[docs]
class Group(dict[str, GroupTransform]):
"""Mapping of destination group names to transform pipelines."""
def __init__(
self,
groups_dest: dict[str, GroupTransform] = {"default|Labels": GroupTransform()},
):
super().__init__(groups_dest)
[docs]
class GroupMetric(dict[str, GroupTransformMetric]):
"""Metric-oriented variant of :class:`Group` used during evaluation."""
def __init__(
self,
groups_dest: dict[str, GroupTransformMetric] = {"default|group_dest": GroupTransformMetric()},
):
super().__init__(groups_dest)
[docs]
class CustomSampler(Sampler[int]):
"""Simple sampler that optionally shuffles indices without distributed logic."""
def __init__(self, size: int, shuffle: bool = False) -> None:
self.size = size
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))
def __len__(self) -> int:
return self.size
[docs]
@dataclass(frozen=True)
class DataItem:
"""Single tensor sample together with dataset metadata and patch indices."""
name: str
tensor: torch.Tensor
attribute: Attribute
x: int
a: int
p: int
is_input: bool
[docs]
@dataclass(frozen=True)
class BatchDataItem:
"""Batch-level representation of multiple :class:`DataItem` objects."""
name: list[str]
tensor: torch.Tensor # [B, ...]
attribute: list[Attribute]
x: list[int]
a: list[int]
p: list[int]
is_input: bool
Sample: TypeAlias = dict[str, DataItem]
BatchSample: TypeAlias = dict[str, BatchDataItem]
[docs]
def collate_konfai(batch: list[Sample]) -> BatchSample:
"""Collate KonfAI samples into the batch structure expected by the workflows."""
batch_sample: BatchSample = {}
for k in batch[0].keys():
items = [b[k] for b in batch]
batch_sample[k] = BatchDataItem(
tensor=torch.stack([it.tensor for it in items], dim=0),
x=[it.x for it in items],
a=[it.a for it in items],
p=[it.p for it in items],
attribute=[it.attribute for it in items],
name=[it.name for it in items],
is_input=items[0].is_input,
)
return batch_sample
[docs]
class DatasetIter(data.Dataset):
"""Torch dataset view over KonfAI dataset managers and patch mappings."""
def __init__(
self,
rank: int,
data: dict[str, list[DatasetManager]],
mapping: list[tuple[int, int, int]],
groups_src: Mapping[str, Group | GroupMetric],
inline_augmentations: bool,
data_augmentations_list: list[DataAugmentationsList],
patch_size: list[int] | None,
overlap: int | None,
buffer_size: int,
use_cache=True,
) -> None:
self.rank = rank
self.data = data
self.mapping = mapping
self.patch_size = patch_size
self.overlap = overlap
self.groups_src = groups_src
self.data_augmentations_list = data_augmentations_list
self.use_cache = use_cache
self.nb_dataset = len(data[list(data.keys())[0]])
self.buffer_size = buffer_size
self._index_cache: list[int] = []
self._index_cache_lookup: set[int] = set()
self.inline_augmentations = inline_augmentations
[docs]
def get_patch_config(self) -> tuple[list[int] | None, int | None]:
return self.patch_size, self.overlap
[docs]
def to(self, device: int):
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
self.groups_src[group_src][group_dest].to(device)
for data_augmentations in self.data_augmentations_list:
for data_augmentation in data_augmentations.data_augmentations:
data_augmentation.to(device)
[docs]
def get_dataset_from_index(self, group_dest: str, index: int) -> DatasetManager:
return self.data[group_dest][index]
[docs]
def reset_augmentation(self, label):
if self.inline_augmentations and len(self.data_augmentations_list) > 0:
for index in range(self.nb_dataset):
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
self.data[group_dest][index].unload_augmentation()
self.data[group_dest][index].reset_augmentation()
self.load(label + " Augmentation")
[docs]
def load(self, label: str):
if self.use_cache:
memory_init = get_memory()
indexs = list(range(self.nb_dataset))
if len(indexs) > 0:
memory_lock = threading.Lock()
def desc(i: int = 0):
return (
f"Caching {label}: "
f"{get_memory_info()} | "
f"{memory_forecast(memory_init, i, self.nb_dataset)} | "
f"{get_cpu_info()}"
)
pbar = tqdm.tqdm(total=len(indexs), desc=desc(), leave=False)
stop_event = threading.Event()
def process(index):
if stop_event.is_set():
return
self._load_data(index)
with memory_lock:
pbar.set_description(desc(pbar.n + 1))
pbar.update(1)
cpu_count = os.cpu_count() or 1
try:
with ThreadPoolExecutor(
max_workers=cpu_count // (device_count() if device_count() > 0 else 1)
) as executor:
future_to_index = {executor.submit(process, index): index for index in indexs}
for fut in as_completed(future_to_index):
index = future_to_index[fut]
try:
fut.result()
except Exception as e:
stop_event.set()
for f in future_to_index:
f.cancel()
tb = traceback.format_exc()
raise RuntimeError(
f"Error while caching {label} (index={index})\n"
f"{type(e).__name__}: {e}\n\n"
f"Traceback (worker):\n{tb}"
) from e
except KeyboardInterrupt:
stop_event.set()
try:
for f in future_to_index:
f.cancel()
except Exception: # nosec B110
pass
raise
finally:
pbar.close()
def _load_data(self, index: int, augmentation_index: int | None = None) -> bool:
loaded = False
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
loaded |= self.load_data(group_src, group_dest, index, augmentation_index)
if loaded and index not in self._index_cache_lookup:
self._index_cache.append(index)
self._index_cache_lookup.add(index)
return loaded
[docs]
def load_data(self, group_src: str, group_dest: str, index: int, augmentation_index: int | None = None) -> bool:
item = self.data[group_dest][index]
if augmentation_index is not None and item.can_stream_patch(augmentation_index):
return False
try:
item.load(
self.groups_src[group_src][group_dest].transforms,
self.data_augmentations_list,
load_augmentations=not self.inline_augmentations,
)
except Exception as e:
raise RuntimeError(
f"Error while loading data "
f"(group_src={group_src}, group_dest={group_dest}, "
f"index={index}, name={item.name}) : "
f"{type(e).__name__}: {e}"
) from e
return True
def _unload_data(self, index: int) -> None:
if index in self._index_cache_lookup:
self._index_cache_lookup.remove(index)
self._index_cache.remove(index)
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
self.unload_data(group_dest, index)
[docs]
def unload_data(self, group_dest: str, index: int) -> None:
return self.data[group_dest][index].unload()
def __len__(self) -> int:
return len(self.mapping)
def __getitem__(self, index: int) -> Sample:
sample: Sample = {}
x, a, p = self.mapping[index]
needs_full_load = any(
not self.data[group_dest][x].can_stream_patch(a)
for group_src in self.groups_src
for group_dest in self.groups_src[group_src]
)
if x not in self._index_cache_lookup and needs_full_load:
if len(self._index_cache) >= self.buffer_size and not self.use_cache:
self._unload_data(self._index_cache[0])
self._load_data(x, a)
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
dataset = self.data[group_dest][x]
sample[f"{group_dest}"] = DataItem(
dataset.name,
dataset.get_data(
p,
a,
self.groups_src[group_src][group_dest].patch_transforms,
self.groups_src[group_src][group_dest].is_input,
),
dataset.cache_attributes[a],
x,
a,
p,
self.groups_src[group_src][group_dest].is_input,
)
return sample
[docs]
class Subset:
def __init__(
self,
subset: str | list[int] | list[str] | None = None,
shuffle: bool = True,
) -> None:
self.subset = subset
self.shuffle = shuffle
@staticmethod
def _read_names_from_file(filename: str) -> list[str]:
with open(filename) as f:
return [name.strip() for name in f if name.strip()]
[docs]
def requires_infos(self) -> bool:
"""Return whether this subset implementation needs per-sample metadata."""
return self.__class__.__call__ is not Subset.__call__
@staticmethod
def _is_slice_selector(subset: str) -> bool:
start, sep, end = subset.partition(":")
if sep == "":
return False
return start.lstrip("-").isdigit() and end.lstrip("-").isdigit()
def _resolve_selector(self, subset: str | int, names: list[str]) -> tuple[set[int], bool]:
size = len(names)
name_to_index = {name: i for i, name in enumerate(names)}
if isinstance(subset, int):
return {subset}, False
if subset.startswith("~"):
excluded = subset[1:]
if os.path.exists(excluded):
exclude_names = set(self._read_names_from_file(excluded))
return {i for i, name in enumerate(names) if name in exclude_names}, True
if excluded in name_to_index:
return {name_to_index[excluded]}, True
return set(), True
if os.path.exists(subset):
selected_names = set(self._read_names_from_file(subset))
return {i for i, name in enumerate(names) if name in selected_names}, False
if self._is_slice_selector(subset):
start, _, end = subset.partition(":")
r = np.clip(
np.asarray([int(start), int(end)]),
0,
size,
)
return set(range(int(r[0]), int(r[1]))), False
if subset in name_to_index:
return {name_to_index[subset]}, False
return set(), False
def _get_index(self, subset: str | int, names: list[str]) -> list[int]:
index, is_exclusion = self._resolve_selector(subset, names)
if is_exclusion:
return [i for i in range(len(names)) if i not in index]
return list(index)
def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
names = sorted(names)
size = len(names)
if self.subset is None:
index = list(range(0, size))
elif isinstance(self.subset, list):
if len(self.subset) == 0:
index = []
else:
include_index: set[int] = set()
exclude_index: set[int] = set()
has_include = False
for s in self.subset:
resolved_index, is_exclusion = self._resolve_selector(s, names)
if is_exclusion:
exclude_index.update(resolved_index)
else:
include_index.update(resolved_index)
has_include = True
index_set = include_index if has_include else set(range(size))
index = list(index_set.difference(exclude_index))
else:
index = self._get_index(self.subset, names)
if self.shuffle:
index = random.sample(index, len(index)) # nosec B311
return {names[i] for i in index}
def __str__(self):
return "Subset : " + str(self.subset) + " shuffle : " + str(self.shuffle)
[docs]
class TrainSubset(Subset):
def __init__(
self,
subset: str | list[int] | list[str] | None = None,
shuffle: bool = True,
) -> None:
super().__init__(subset, shuffle)
[docs]
class PredictionSubset(Subset):
def __init__(self, subset: str | list[int] | list[str] | None = None) -> None:
super().__init__(subset, False)
[docs]
class Data(ABC):
"""Abstract base class shared by training, prediction, and evaluation datasets."""
@staticmethod
def _configured_transform_requires_single_process(classpath: str) -> bool:
for transform_name in classpath.split("|"):
candidate = transform_name.split(":")[-1].split(".")[-1].split("/")[0]
if candidate == "KonfAIInference":
return True
return False
@classmethod
def _groups_require_single_process_loading(cls, groups_src: Mapping[str, Group | GroupMetric]) -> bool:
for group in groups_src.values():
for group_transform in group.values():
for configured_transforms in (group_transform._transforms, group_transform._patch_transforms):
if configured_transforms is None:
continue
if any(
cls._configured_transform_requires_single_process(classpath)
for classpath in configured_transforms
):
return True
return False
@staticmethod
def _read_names_from_file(filename: str) -> list[str]:
with open(filename) as f:
return [name.strip() for name in f if name.strip()]
@classmethod
def _resolve_name_selectors(cls, selectors: list[str]) -> set[str]:
resolved_names: set[str] = set()
for selector in selectors:
if os.path.exists(selector):
resolved_names.update(cls._read_names_from_file(selector))
else:
resolved_names.add(selector)
return resolved_names
@abstractmethod
def __init__(
self,
dataset_filenames: list[str],
groups_src: Mapping[str, Group | GroupMetric],
patch: DatasetPatch | None,
use_cache: bool,
subset: Subset,
batch_size: int,
validation: float | str | list[int] | list[str] | None,
inline_augmentations: bool,
data_augmentations_list: dict[str, DataAugmentationsList],
num_workers: int | None,
pin_memory: bool,
prefetch_factor: int | None,
persistent_workers: bool | None,
) -> None:
self.dataset_filenames = dataset_filenames
self.subset = subset
self.groups_src = groups_src
self.patch = patch
self.validation = validation
self.data_augmentations_list = data_augmentations_list
self.batch_size = batch_size
self.use_cache = use_cache
self.inline_augmentations = inline_augmentations
self.requires_single_process_loading = self._groups_require_single_process_loading(groups_src)
self.datasetIter = partial(
DatasetIter,
groups_src=self.groups_src,
inline_augmentations=inline_augmentations,
data_augmentations_list=list(self.data_augmentations_list.values()),
patch_size=self.patch.patch_size if self.patch is not None else None,
overlap=self.patch.overlap if self.patch is not None else None,
buffer_size=batch_size + 1,
use_cache=use_cache,
)
resolved_num_workers = num_workers
if self.requires_single_process_loading:
resolved_num_workers = 0
elif resolved_num_workers is None:
resolved_num_workers = max(1, min(os.cpu_count() or 1, 4)) if not use_cache else 0
self.dataLoader_args = {
"num_workers": resolved_num_workers,
"pin_memory": pin_memory,
"collate_fn": collate_konfai,
}
if resolved_num_workers > 0:
self.dataLoader_args["prefetch_factor"] = 2 if prefetch_factor is None else prefetch_factor
self.dataLoader_args["persistent_workers"] = True if persistent_workers is None else persistent_workers
self.data: list[list[dict[str, list[DatasetManager]]]] = []
self.mapping: list[list[list[tuple[int, int, int]]]] = []
self.datasets: dict[str, Dataset] = {}
self._prepared_data: dict[str, list[DatasetManager]] | None = None
self._prepared_mapping: list[tuple[int, int, int]] = []
self._prepared_validation_mapping: list[tuple[int, int, int]] = []
self._prepared_train_names: list[str] = []
self._prepared_validation_names: list[str] = []
[docs]
def prepare(self) -> None:
"""Instantiate config-driven transforms and augmentations before runtime."""
if self._prepared_data is not None:
return
model_have_input = False
last_group_src: str | None = None
for group_src in self.groups_src:
last_group_src = group_src
for group_dest in self.groups_src[group_src]:
self.groups_src[group_src][group_dest].prepare(group_src, group_dest)
model_have_input |= self.groups_src[group_src][group_dest].is_input
if self.patch is not None:
self.patch.init()
if not model_have_input:
raise DatasetManagerError(
"At least one group must be defined with 'is_input: true' to provide input to the network."
)
if last_group_src is not None:
for key, data_augmentations in self.data_augmentations_list.items():
data_augmentations.prepare(key)
self._prepare_datasets()
def _resolve_dataset_sources(self) -> dict[str, list[tuple[str, bool]]]:
datasets: dict[str, list[tuple[str, bool]]] = {}
if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
raise DatasetManagerError("No dataset filenames were provided")
self.datasets = {}
for dataset_filename in self.dataset_filenames:
if dataset_filename is None:
raise DatasetManagerError(
"Invalid dataset entry: 'None' received.",
"Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, "
"'./Dataset/:a:mha', './Dataset/:i:mha').",
"Please check your 'dataset_filenames' list for missing or null entries.",
)
filename, flag, file_format = split_path_spec(
dataset_filename,
default_format="mha",
allowed_flags={"a", "i"},
supported_extensions=SUPPORTED_EXTENSIONS,
)
append = flag != "i"
if file_format not in SUPPORTED_EXTENSIONS:
raise DatasetManagerError(
f"Unsupported file format '{file_format}'.",
f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}",
)
dataset = Dataset(filename, file_format)
self.datasets[filename] = dataset
for group in self.groups_src:
if dataset.is_group_exist(group):
datasets.setdefault(group, []).append((filename, append))
for group_src in self.groups_src:
if group_src not in datasets:
raise DatasetManagerError(
f"Group source '{group_src}' not found in any dataset.",
f"Dataset filenames provided: {self.dataset_filenames}",
f"Available groups across all datasets: "
f"{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}\n"
f"Please check that an entry in the dataset with the name '{group_src}' exists.",
)
for group_dest in self.groups_src[group_src]:
self.groups_src[group_src][group_dest].set_datasets(list(self.datasets.values()))
for _group_src, entries in datasets.items():
for _key, data_augmentations in self.data_augmentations_list.items():
data_augmentations.set_datasets([self.datasets[filename] for filename, _ in entries])
break
return datasets
def _resolve_common_names(
self,
datasets: dict[str, list[tuple[str, bool]]],
) -> tuple[dict[str, dict[str, list[str]]], set[str]]:
dataset_name: dict[str, dict[str, list[str]]] = {}
subset_requires_infos = self.subset.requires_infos()
dataset_info: dict[str, dict[str, dict[str, tuple[list[int], Attribute]]]] | None = (
{} if subset_requires_infos else None
)
empty_infos: dict[str, tuple[list[int], Attribute]] = {}
names: set[str] = set()
for group in self.groups_src:
names_by_group = set()
dataset_name[group] = {}
if dataset_info is not None:
dataset_info[group] = {}
for filename, _ in datasets[group]:
group_names = self.datasets[filename].get_names(group)
names_by_group.update(group_names)
dataset_name[group][filename] = group_names
if dataset_info is not None:
dataset_info[group][filename] = {
name: self.datasets[filename].get_infos(group, name) for name in group_names
}
if len(names) == 0:
names.update(names_by_group)
else:
names = names.intersection(names_by_group)
if len(names) == 0:
raise DatasetManagerError(
f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data "
"from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
)
subset_names: set[str] = set()
for group in dataset_name:
subset_names_bygroup: set[str] = set()
for filename, append in datasets[group]:
resolved_subset = self.subset(
dataset_name[group][filename],
dataset_info[group][filename] if dataset_info is not None else empty_infos,
)
if append:
subset_names_bygroup.update(resolved_subset)
elif len(subset_names_bygroup) == 0:
subset_names_bygroup.update(resolved_subset)
else:
subset_names_bygroup = subset_names_bygroup.intersection(resolved_subset)
if len(subset_names) == 0:
subset_names.update(subset_names_bygroup)
else:
subset_names = subset_names.intersection(subset_names_bygroup)
if len(subset_names) == 0:
raise DatasetManagerError(
"All data entries were excluded by the subset filter.",
f"Dataset entries found: {', '.join(names)}",
f"Subset object applied: {self.subset}",
f"Subset requested : {', '.join(subset_names)}",
"None of the dataset entries matched the given subset.",
"Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
"Examples of valid subset formats:",
"\tsubset: [0, 1] # explicit indices",
"\tsubset: [./A.txt, ./B.txt]# union of multiple files",
"\tsubset: 0:10 # slice notation",
"\tsubset: ./Validation.txt # external file",
"\tsubset: None # to disable filtering",
)
return dataset_name, subset_names
def _split_train_validation(
self,
subset_names: list[str],
mapping: list[tuple[int, int, int]],
) -> tuple[list[tuple[int, int, int]], list[tuple[int, int, int]], list[str], list[str]]:
index: list[int] = []
if isinstance(self.validation, float):
if self.validation <= 0 or self.validation >= 1:
raise DatasetManagerError(
"Validation must be a float between 0 and 1.",
f"Received: {self.validation}",
"Example: validation = 0.2 # for a 20% validation split",
)
index = [m[0] for m in mapping[int(math.floor(len(mapping) * (1 - self.validation))) :]]
elif isinstance(self.validation, str):
if ":" in self.validation:
index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
elif os.path.exists(self.validation):
validation_names = []
with open(self.validation) as f:
for name in f:
validation_names.append(name.strip())
index = [i for i, n in enumerate(subset_names) if n in validation_names]
else:
raise DatasetManagerError(
f"Invalid string value for 'validation': '{self.validation}'",
"Expected one of the following formats:",
"\t• A slice string like '0:10'",
"\t• A path to a text file listing validation sample names (e.g., './val.txt')",
"\t• A list of text files listing validation sample names",
"\t• A float between 0 and 1 (e.g., 0.2)",
"\t• A list of sample names or indices",
"The provided value is neither a valid slice nor a readable file.",
"Please fix your 'validation' setting in the configuration.",
)
elif isinstance(self.validation, list):
if len(self.validation) == 0:
index = []
elif all(isinstance(item, int) for item in self.validation):
index = cast(list[int], self.validation)
elif all(isinstance(item, str) for item in self.validation):
validation_name_set = self._resolve_name_selectors(cast(list[str], self.validation))
index = [i for i, n in enumerate(subset_names) if n in validation_name_set]
else:
element_types = sorted({type(item).__name__ for item in self.validation})
raise DatasetManagerError(
"Invalid list type for 'validation': elements of type " f"{element_types} are not supported.",
"Supported list element types are:",
"\t• int → list of indices (e.g., [0, 1, 2])",
"\t• str → list of sample names or file paths",
f"Received list: {self.validation}",
)
index_set = set(index)
train_mapping = [m for m in mapping if m[0] not in index_set]
validate_mapping = [m for m in mapping if m[0] in index_set]
if len(train_mapping) == 0:
raise DatasetManagerError(
"No data left for training after applying the validation split.",
f"Dataset size: {len(mapping)}",
f"Validation setting: {self.validation}",
"Please reduce the validation size, increase the dataset, or disable validation.",
)
if self.validation is not None and len(validate_mapping) == 0:
raise DatasetManagerError(
"No data left for validation after applying the validation split.",
f"Dataset size: {len(mapping)}",
f"Validation setting: {self.validation}",
"Please increase the validation size, increase the dataset, or disable validation.",
)
validation_names = [name for i, name in enumerate(subset_names) if i in index_set]
validation_names_set = set(validation_names)
train_names = [name for name in subset_names if name not in validation_names_set]
return train_mapping, validate_mapping, train_names, validation_names
def _prepare_datasets(self) -> None:
"""Resolve dataset files, validate subsets, and precompute train/validation mappings."""
datasets = self._resolve_dataset_sources()
dataset_name, subset_names = self._resolve_common_names(datasets)
subset_names_list = list(subset_names)
data, mapping = self._get_datasets(subset_names_list, dataset_name)
train_mapping, validate_mapping, train_names, validation_names = self._split_train_validation(
subset_names_list, mapping
)
self._prepared_data = data
self._prepared_mapping = train_mapping
self._prepared_validation_mapping = validate_mapping
self._prepared_train_names = train_names
self._prepared_validation_names = validation_names
def _get_datasets(
self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]
) -> tuple[dict[str, list[DatasetManager]], list[tuple[int, int, int]]]:
nb_dataset = len(names)
nb_patch: list[list[int]]
data = {}
mapping = []
source_filename_by_group: dict[str, dict[str, str]] = {}
nb_augmentation = np.max(
[
int(np.sum([data_augmentation.nb for data_augmentation in self.data_augmentations_list.values()]) + 1),
1,
]
)
for group_src, filenames_by_group in dataset_name.items():
source_filename_by_group[group_src] = {}
for filename, group_names in filenames_by_group.items():
for name in group_names:
source_filename_by_group[group_src].setdefault(name, filename)
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
data[group_dest] = [
DatasetManager(
i,
group_src,
group_dest,
name,
self.datasets[source_filename_by_group[group_src][name]],
patch=self.patch,
transforms=self.groups_src[group_src][group_dest].transforms,
data_augmentations_list=list(self.data_augmentations_list.values()),
)
for i, name in enumerate(names)
]
nb_patch = [[dataset.get_size(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
for x in range(nb_dataset):
for y in range(nb_augmentation):
for z in range(nb_patch[x][y]):
mapping.append((x, y, z))
return data, mapping
[docs]
def get_groups_dest(self):
groups_dest = []
for group_src in self.groups_src:
for group_dest in self.groups_src[group_src]:
groups_dest.append(group_dest)
return groups_dest
@staticmethod
def _split(mapping: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
if len(mapping) == 0:
return [[] for _ in range(world_size)]
mappings: list[list[tuple[int, int, int]]] = []
if konfai_state() == str(State.PREDICTION) or konfai_state() == str(State.EVALUATION):
mapping_by_index: dict[int, list[tuple[int, int, int]]] = {}
for entry in mapping:
mapping_by_index.setdefault(entry[0], []).append(entry)
unique_index = np.asarray(sorted(mapping_by_index))
for shard in np.array_split(unique_index, world_size):
shard_mapping: list[tuple[int, int, int]] = []
for dataset_index in shard.tolist():
shard_mapping.extend(mapping_by_index[int(dataset_index)])
mappings.append(shard_mapping)
else:
size = len(mapping)
for rank in range(world_size):
start = (size * rank) // world_size
end = (size * (rank + 1)) // world_size
mappings.append(mapping[start:end])
return mappings
@staticmethod
def _remap_dataset_indices(mapping_tmp: list[tuple[int, int, int]]) -> tuple[list[int], list[tuple[int, int, int]]]:
"""Compress sparse dataset indices into local contiguous indices for one loader shard."""
local_indices: list[int] = []
index_map: dict[int, int] = {}
remapped_mapping: list[tuple[int, int, int]] = []
for dataset_index, augmentation_index, patch_index in mapping_tmp:
local_index = index_map.get(dataset_index)
if local_index is None:
local_index = len(local_indices)
local_indices.append(dataset_index)
index_map[dataset_index] = local_index
remapped_mapping.append((local_index, augmentation_index, patch_index))
return local_indices, remapped_mapping
[docs]
def get_data(self, world_size: int) -> tuple[list[list[DataLoader]], list[str], list[str]]:
if self._prepared_data is None:
raise DatasetManagerError("Dataset configuration was not prepared before runtime data loading.")
self.data = []
self.mapping = []
train_mappings = Data._split(self._prepared_mapping, world_size)
validate_mappings = Data._split(self._prepared_validation_mapping, world_size)
for i, (train_mapping, validate_mapping) in enumerate(zip(train_mappings, validate_mappings)):
mappings = [train_mapping]
if len(validate_mapping):
mappings += [validate_mapping]
self.data.append([])
self.mapping.append([])
for mapping_tmp in mappings:
indexs, remapped_mapping = self._remap_dataset_indices(mapping_tmp)
self.data[i].append({k: [v[it] for it in indexs] for k, v in self._prepared_data.items()})
self.mapping[i].append(remapped_mapping)
data_loaders: list[list[DataLoader]] = []
for i, (datas, mappings) in enumerate(zip(self.data, self.mapping)):
data_loaders.append([])
for dataset_items, mapping in zip(datas, mappings):
data_loaders[i].append(
DataLoader(
dataset=self.datasetIter(
rank=i,
data=dataset_items,
mapping=mapping,
),
sampler=CustomSampler(len(mapping), self.subset.shuffle),
batch_size=self.batch_size,
**self.dataLoader_args,
)
)
return data_loaders, self._prepared_train_names, self._prepared_validation_names
def __str__(self) -> str:
params = {
"dataset_filenames": self.dataset_filenames,
"groups_src": self.groups_src,
"patch": self.patch,
"use_cache": self.use_cache,
"subset": self.subset,
"batch_size": self.batch_size,
"validation": self.validation,
"inline_augmentations": self.inline_augmentations,
"data_augmentations_list": self.data_augmentations_list,
}
return str(params)
def __repr__(self) -> str:
return str(self)
[docs]
@config("Dataset")
class DataTrain(Data):
"""Dataset configuration used by the training workflow."""
def __init__(
self,
dataset_filenames: list[str] = ["default|./Dataset:mha"],
groups_src: dict[str, Group] = {"default|Labels": Group()},
augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
inline_augmentations: bool = False,
patch: DatasetPatch | None = DatasetPatch(),
use_cache: bool = True,
subset: TrainSubset = TrainSubset(),
batch_size: int = 1,
validation: float | str | list[int] | list[str] | None = 0.2,
num_workers: int | None = None,
pin_memory: bool = False,
prefetch_factor: int | None = None,
persistent_workers: bool | None = None,
) -> None:
super().__init__(
dataset_filenames,
groups_src,
patch,
use_cache,
subset,
batch_size,
validation,
inline_augmentations,
augmentations if augmentations else {},
num_workers,
pin_memory,
prefetch_factor,
persistent_workers,
)
[docs]
@config("Dataset")
class DataPrediction(Data):
"""Dataset configuration used by the prediction workflow."""
def __init__(
self,
dataset_filenames: list[str] = ["default|./Dataset"],
groups_src: dict[str, Group] = {"default": Group()},
augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
patch: DatasetPatch | None = DatasetPatch(),
subset: PredictionSubset = PredictionSubset(),
batch_size: int = 1,
num_workers: int | None = None,
pin_memory: bool = False,
prefetch_factor: int | None = None,
persistent_workers: bool | None = None,
) -> None:
super().__init__(
dataset_filenames=dataset_filenames,
groups_src=groups_src,
patch=patch,
use_cache=False,
subset=subset,
batch_size=batch_size,
validation=None,
inline_augmentations=False,
data_augmentations_list=augmentations if augmentations else {},
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
persistent_workers=False if persistent_workers is None else persistent_workers,
)
[docs]
@config("Dataset")
class DataMetric(Data):
"""Dataset configuration used by the evaluation workflow."""
def __init__(
self,
dataset_filenames: list[str] = ["default|./Dataset:mha"],
groups_src: dict[str, GroupMetric] = {"default": GroupMetric()},
subset: PredictionSubset = PredictionSubset(),
validation: str | list[int] | list[str] | None = None,
num_workers: int | None = None,
pin_memory: bool = False,
prefetch_factor: int | None = None,
persistent_workers: bool | None = None,
) -> None:
super().__init__(
dataset_filenames=dataset_filenames,
groups_src=groups_src,
patch=None,
use_cache=True,
subset=subset,
batch_size=1,
validation=validation,
data_augmentations_list={},
inline_augmentations=False,
num_workers=num_workers,
pin_memory=pin_memory,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
)