Source code for konfai.data.patching

# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Patch extraction, accumulation, and patch-combination helpers for KonfAI."""

import copy
import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial

import numpy as np
import SimpleITK as sitk  # noqa: N813
import torch
import torch.nn.functional as F  # noqa: N812

from konfai.data.augmentation import DataAugmentationsList
from konfai.data.transform import Clip, Normalize, Save, Standardize, TensorCast, Transform
from konfai.utils.config import apply_config, config
from konfai.utils.dataset import Attribute, Dataset
from konfai.utils.utils import SUPPORTED_EXTENSIONS, get_module, get_patch_slices_from_shape, split_path_spec


[docs] @dataclass(frozen=True) class PatchReadPlan: """Precomputed slicing and padding instructions for one patch request.""" data_slices: tuple[slice, ...] reflect_padding: tuple[int, ...] constant_padding: tuple[int, ...] concatenate_extend_slice: bool
[docs] class PathCombine(ABC): """Base class for overlap-aware weighting schemes applied during patch assembly.""" def __init__(self) -> None: self.data: torch.Tensor self.overlap: int self._data_per_device: dict[torch.device, torch.Tensor] = {} """ A = slice(0, overlap) B = slice(-overlap, None) C = slice(overlap, -overlap) 1D A+B 2D : AA+AB+BA+BB AC+BC CA+CB 3D : AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB CAA+CAB+CBA+CBB ACA+ACB+BCA+BCB AAC+ABC+BAC+BBC CCA+CCB CAC+CBC ACC+BCC """
[docs] def set_patch_config(self, patch_size: list[int], overlap: int): self._data_per_device.clear() self.data = F.pad( torch.ones([size - overlap * 2 for size in patch_size]), [overlap] * 2 * len(patch_size), mode="constant", value=0, ) self.data = self._set_function(self.data, overlap) dim = len(patch_size) a = slice(0, overlap) b = slice(-overlap, None) c = slice(overlap, -overlap) for i in range(dim): slices_badge = list(itertools.product(*[[a, b] for _ in range(dim - i)])) for indexs in itertools.combinations([0, 1, 2], i): result = [] for slices_tuple in slices_badge: slices_list = list(slices_tuple) for index in indexs: slices_list.insert(index, c) result.append(tuple(slices_list)) for patch, s in zip(PathCombine._normalise([self.data[s] for s in result]), result): self.data[s] = patch
@staticmethod def _normalise(patchs: list[torch.Tensor]) -> list[torch.Tensor]: data_sum = torch.sum(torch.concat([patch.unsqueeze(0) for patch in patchs], dim=0), dim=0) return [d / data_sum for d in patchs] def __call__(self, tensor: torch.Tensor) -> torch.Tensor: if tensor.device not in self._data_per_device: self._data_per_device[tensor.device] = self.data.to(tensor.device) return self._data_per_device[tensor.device] * tensor @abstractmethod def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor: pass
[docs] class Mean(PathCombine): """Uniform patch-combination strategy for overlapping predictions.""" def __init__(self) -> None: super().__init__() def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor: return torch.ones_like(self.data)
[docs] class Cosinus(PathCombine): """Cosine-based weighting strategy for smoother overlap blending.""" def __init__(self) -> None: super().__init__() def _function_sides(self, overlap: int, x: float): return np.clip(np.cos(np.pi / (2 * (overlap + 1)) * x), 0, 1) def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor: image = sitk.GetImageFromArray(np.asarray(data, dtype=np.uint8)) danielsson_distance_map_image_filter = sitk.DanielssonDistanceMapImageFilter() distance = torch.tensor(sitk.GetArrayFromImage(danielsson_distance_map_image_filter.Execute(image))) return distance.apply_(partial(self._function_sides, overlap))
[docs] class Accumulator: """Accumulate patch predictions and reassemble them into a full tensor.""" def __init__( self, patch_slices: list[tuple[slice, ...]], patch_size: list[int], patch_combine: PathCombine | None = None, batch: bool = True, ) -> None: self._layer_accumulator: list[torch.Tensor | None] = [None] * len(patch_slices) self.patch_slices: list[tuple[slice, ...]] = [] if patch_size is not None and not all(p == 0 for p in patch_size): for patch in patch_slices: slices: list[slice] = [] for s, shape in zip(patch, patch_size): slices.append(slice(s.start, s.start + shape)) self.patch_slices.append(tuple(slices)) else: self.patch_slices = patch_slices self.shape = max([[v.stop for v in patch] for patch in patch_slices]) self.patch_size = patch_size self.patch_combine = patch_combine self.batch = batch
[docs] def add_layer(self, index: int, layer: torch.Tensor) -> None: self._layer_accumulator[index] = layer
[docs] def is_full(self) -> bool: return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
[docs] def assemble(self) -> torch.Tensor: n = 2 if self.batch else 1 if self._layer_accumulator[0] is not None: result = torch.zeros( ( list(self._layer_accumulator[0].shape[:n]) + list(max([[v.stop for v in patch] for patch in self.patch_slices])) ), dtype=self._layer_accumulator[0].dtype, ).to(self._layer_accumulator[0].device) for patch_slice, data in zip(self.patch_slices, self._layer_accumulator): if data is not None: slices_dest = tuple([slice(result.shape[i]) for i in range(n)] + list(patch_slice)) for dim, s in enumerate(patch_slice): if s.stop - s.start == 1: data = data.unsqueeze(dim=dim + n) if self.patch_combine is not None: result[slices_dest] += self.patch_combine(data) else: result[slices_dest] = data result = result[tuple([slice(None, None)] + [slice(0, s) for s in self.shape])] self._layer_accumulator.clear() return result
[docs] class Patch(ABC): """Abstract base class for dataset-level and model-level patch definitions.""" @abstractmethod def __init__( self, patch_size: list[int], overlap: int | None, pad_value: float | None = 0, extend_slice: int = 0, ) -> None: if extend_slice != 0 and patch_size is not None and patch_size[0] != 1: raise ValueError( "`extend_slice` can only be used when patch_size[0] == 1 " f"(got patch_size[0]={patch_size[0]}, extend_slice={extend_slice})" ) self.patch_size = patch_size self.overlap = overlap if isinstance(self.overlap, int): if self.overlap < 0: self.overlap = None self._patch_slices: dict[int, list[tuple[slice, ...]]] = {} self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {} self.pad_value = pad_value self.extend_slice = extend_slice
[docs] def load(self, shape: list[int], a: int = 0) -> None: self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape( self.patch_size, shape, self.overlap )
[docs] @abstractmethod def init(self, key: str): pass
[docs] def get_patch_slices(self, a: int = 0): return self._patch_slices[a]
[docs] def get_read_plan( self, data_shape: list[int] | tuple[int, ...], index: int, a: int, is_input: bool ) -> PatchReadPlan: slices_pre = [slice(None) for _ in data_shape[: -len(self._patch_slices[a][0])]] extend_slice = self.extend_slice if is_input else 0 bottom = extend_slice // 2 top = int(np.ceil(extend_slice / 2)) s = slice( ( self._patch_slices[a][index][0].start - bottom if self._patch_slices[a][index][0].start - bottom >= 0 else 0 ), ( self._patch_slices[a][index][0].stop + top if self._patch_slices[a][index][0].stop + top <= data_shape[len(slices_pre)] else data_shape[len(slices_pre)] ), ) slices = [s] + list(self._patch_slices[a][index][1:]) reflect_padding = [0 for _ in range((len(slices) - 1) * 2)] + [0, 0] if extend_slice > 0 and (s.stop - s.start) < bottom + top + 1: if self._patch_slices[a][index][0].start - bottom < 0: reflect_padding[-2] = bottom - self._patch_slices[a][index][0].start if self._patch_slices[a][index][0].stop + top > data_shape[len(slices_pre)]: reflect_padding[-1] = self._patch_slices[a][index][0].stop + top - data_shape[len(slices_pre)] constant_padding = [] if self.patch_size is not None and not all(p == 0 for p in self.patch_size): for dim_it, _slice in enumerate(reversed(slices)): p = ( 0 if _slice.start + self.patch_size[-dim_it - 1] <= data_shape[-dim_it - 1] else self.patch_size[-dim_it - 1] - (data_shape[-dim_it - 1] - _slice.start) ) constant_padding.append(0) constant_padding.append(p) return PatchReadPlan( data_slices=tuple(slices_pre + slices), reflect_padding=tuple(reflect_padding), constant_padding=tuple(constant_padding), concatenate_extend_slice=extend_slice > 0, )
[docs] def apply_read_plan(self, data: torch.Tensor, plan: PatchReadPlan) -> torch.Tensor: data_sliced = data if any(plan.reflect_padding): data_sliced = F.pad(data_sliced, plan.reflect_padding, "reflect") if any(plan.constant_padding): data_sliced = F.pad( data_sliced, plan.constant_padding, "constant", ( 0 if data_sliced.dtype == torch.uint8 else (self.pad_value if self.pad_value is not None else float(data.min().item())) ), ) if self.patch_size is not None and not all(p == 0 for p in self.patch_size): for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]: data_sliced = torch.squeeze(data_sliced, dim=len(data_sliced.shape) - d - 1) return ( torch.cat([data_sliced[:, i, ...] for i in range(data_sliced.shape[1])], dim=0) if plan.concatenate_extend_slice else data_sliced )
[docs] def get_data(self, data: torch.Tensor, index: int, a: int, is_input: bool) -> list[torch.Tensor]: plan = self.get_read_plan(list(data.shape), index, a, is_input) data_sliced = data[plan.data_slices] return self.apply_read_plan(data_sliced, plan)
[docs] def get_size(self, a: int = 0) -> int: return len(self._patch_slices[a])
[docs] @config("Patch") class DatasetPatch(Patch): """Patch definition applied when sampling data from datasets.""" def __init__( self, patch_size: list[int] = [128, 128, 128], overlap: int | None = None, pad_value: float | None = None, extend_slice: int = 0, ) -> None: super().__init__(patch_size, overlap, pad_value, extend_slice)
[docs] def init(self, key: str = ""): pass
[docs] @config() class ModelPatch(Patch): """Patch definition applied inside model graphs during prediction or training.""" def __init__( self, patch_size: list[int] = [128, 128, 128], overlap: int | None = None, patch_combine: str | None = None, pad_value: float | None = None, extend_slice: int = 0, ) -> None: super().__init__(patch_size, overlap, pad_value, extend_slice) self._patch_combine = patch_combine self.patch_combine: PathCombine | None = None
[docs] def init(self, key: str): if self._patch_combine is not None: module, name = get_module(self._patch_combine, "konfai.data.patching") self.patch_combine = apply_config(key)(getattr(module, name))() if self.patch_size is not None and self.overlap is not None: if self.patch_combine is not None: self.patch_combine.set_patch_config([i for i in self.patch_size if i > 1], self.overlap) else: self.patch_combine = None
[docs] def disassemble(self, *data_list: torch.Tensor) -> Iterator[list[torch.Tensor]]: for i in range(self.get_size()): yield [self.get_data(data, i, 0, True) for data in data_list]
[docs] class DatasetManager: """Cache-backed manager for one dataset case and one source/destination group.""" def __init__( self, index: int, group_src: str, group_dest: str, name: str, dataset: Dataset, patch: DatasetPatch | None, transforms: list[Transform], data_augmentations_list: list[DataAugmentationsList], ) -> None: self.group_src = group_src self.group_dest = group_dest self.name = name self.index = index self.dataset = dataset self.transforms = transforms self.loaded = False self.augmentationLoaded = False self.cache_attributes: list[Attribute] = [] _shape, cache_attribute = self.dataset.get_infos(self.group_src, name) self.base_shape = list(_shape) self.cache_attributes.append(cache_attribute) _shape = list(_shape[1:]) self.data: list[torch.Tensor] = [] self.augmented_data: dict[int, torch.Tensor] = {} self.total_augmentations = 0 for transform_function in transforms: _shape = transform_function.transform_shape(self.group_src, self.name, _shape, cache_attribute) self.patch = ( DatasetPatch( patch_size=patch.patch_size, overlap=patch.overlap, pad_value=patch.pad_value, extend_slice=patch.extend_slice, ) if patch else DatasetPatch(_shape) ) self.patch.load(_shape, 0) self.shape = _shape self.data_augmentations_list = data_augmentations_list self._patch_stream_source: tuple[Dataset, str, list[int], list[Transform]] | None = None self._patch_stream_checked = False self.reset_augmentation() self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
[docs] def reset_augmentation(self): self.cache_attributes[:] = self.cache_attributes[:1] self.augmented_data.clear() self.total_augmentations = 0 i = 1 for data_augmentations in self.data_augmentations_list: shape = [] caches_attribute = [] for _ in range(data_augmentations.nb): shape.append(self.shape) caches_attribute.append(copy.deepcopy(self.cache_attributes[0])) for data_augmentation in data_augmentations.data_augmentations: shape = data_augmentation.state_init(self.index, shape, caches_attribute) for it, s in enumerate(shape): self.cache_attributes.append(caches_attribute[it]) self.patch.load(s, i) i += 1 self.total_augmentations += data_augmentations.nb self.augmentationLoaded = self.total_augmentations == 0
[docs] def load( self, pre_transform: list[Transform], data_augmentations_list: list[DataAugmentationsList], load_augmentations: bool = True, ) -> None: if not self.loaded: self._load(pre_transform) if load_augmentations and not self.augmentationLoaded: self._load_augmentation(data_augmentations_list)
def _load(self, pre_transform: list[Transform]): self.cache_attributes = copy.deepcopy(self.cache_attributes_bak) i = len(pre_transform) data = None for transform_function in reversed(pre_transform): if isinstance(transform_function, Save): if transform_function.dataset: if len(transform_function.dataset.split(":")) > 1: filename, file_format = transform_function.dataset.split(":") else: filename = transform_function.dataset file_format = "mha" dataset = Dataset(filename, file_format) else: dataset = self.dataset group_dest = transform_function.group if transform_function.group else self.group_dest if dataset.is_dataset_exist(group_dest, self.name): data, attrib = dataset.read_data(group_dest, self.name) self.cache_attributes[0].update(attrib) break i -= 1 if i == 0: data, _ = self.dataset.read_data(self.group_src, self.name) data = torch.from_numpy(data) if len(pre_transform): for transform_function in pre_transform[i:]: data = transform_function(self.name, data, self.cache_attributes[0]) if isinstance(transform_function, Save): if transform_function.dataset: if len(transform_function.dataset.split(":")) > 1: filename, file_format = transform_function.dataset.split(":") else: filename = transform_function.dataset file_format = "mha" dataset = Dataset(filename, file_format) else: dataset = self.dataset group_dest = transform_function.group if transform_function.group else self.group_dest dataset.write( group_dest, self.name, data.numpy(), self.cache_attributes[0], ) self.data.append(data) for i in range(len(self.cache_attributes) - 1): self.cache_attributes[i + 1].update(self.cache_attributes[0]) self.loaded = True def _load_augmentation(self, data_augmentations_list: list[DataAugmentationsList]) -> None: start_index = 1 for data_augmentations in data_augmentations_list: self._load_augmentation_group(start_index, data_augmentations) start_index += data_augmentations.nb self.augmentationLoaded = len(self.augmented_data) == self.total_augmentations def _load_augmentation_group(self, start_index: int, data_augmentations: DataAugmentationsList) -> None: if data_augmentations.nb == 0: return indices = range(start_index, start_index + data_augmentations.nb) if all(index in self.augmented_data for index in indices): return a_data = [self.data[0].clone() for _ in range(data_augmentations.nb)] for data_augmentation in data_augmentations.data_augmentations: if data_augmentation.groups is None or self.group_dest in data_augmentation.groups: a_data = data_augmentation(self.name, self.index, a_data) for index, data in zip(indices, a_data): self.augmented_data[index] = data self.augmentationLoaded = len(self.augmented_data) == self.total_augmentations def _get_tensor(self, a: int) -> torch.Tensor: if a == 0: return self.data[0] if a not in self.augmented_data: start_index = 1 for data_augmentations in self.data_augmentations_list: stop_index = start_index + data_augmentations.nb if start_index <= a < stop_index: self._load_augmentation_group(start_index, data_augmentations) break start_index = stop_index else: raise IndexError(f"Augmentation index {a} out of range for dataset '{self.name}'.") return self.augmented_data[a] @staticmethod def _required_stream_stats(transform: Transform) -> tuple[set[str], list[int] | None] | None: if isinstance(transform, Normalize): return {"Min", "Max"}, transform.channels if isinstance(transform, Standardize): required_stats = set() if transform.mean is None: required_stats.add("Mean") if transform.std is None: required_stats.add("Std") return required_stats, None if isinstance(transform, Clip): required_stats = set() if isinstance(transform.min_value, str): if transform.min_value != "min": return None required_stats.add("Min") if isinstance(transform.max_value, str): if transform.max_value != "max": return None required_stats.add("Max") return required_stats, None return None def _ensure_stream_stats( self, source_dataset: Dataset, source_group: str, cache_attribute: Attribute, required_stats: set[str], channels: list[int] | None = None, ) -> bool: missing_stats = [key for key in required_stats if key not in cache_attribute] if not missing_stats: return True stats = source_dataset.read_data_statistics(source_group, self.name, channels) stats_mapping = { "Min": stats["min"], "Max": stats["max"], "Mean": stats["mean"], "Std": stats["std"], } for key in missing_stats: if key in {"Mean", "Std"}: cache_attribute[key] = np.asarray([stats_mapping[key]], dtype=np.float32) else: cache_attribute[key] = stats_mapping[key] return all(key in cache_attribute for key in required_stats) def _supports_patch_stream_transform( self, transform: Transform, source_dataset: Dataset, source_group: str, cache_attribute: Attribute, ) -> bool: if isinstance(transform, TensorCast): return True if isinstance(transform, Clip) and transform.mask is not None: return False if isinstance(transform, Standardize) and transform.mask is not None: return False required_stream_stats = self._required_stream_stats(transform) if required_stream_stats is None: return False required_stats, channels = required_stream_stats return self._ensure_stream_stats(source_dataset, source_group, cache_attribute, required_stats, channels) @staticmethod def _dataset_from_spec(dataset_spec: str) -> Dataset: filename, _, file_format = split_path_spec( dataset_spec, default_format="mha", supported_extensions=SUPPORTED_EXTENSIONS, ) return Dataset(filename, file_format) def _resolve_patch_stream_source(self) -> tuple[Dataset, str, list[int], list[Transform]] | None: if self._patch_stream_checked: return self._patch_stream_source source_dataset = self.dataset source_group = self.group_src source_shape = self.base_shape trailing_transforms = self.transforms for index in range(len(self.transforms) - 1, -1, -1): transform = self.transforms[index] if isinstance(transform, Save): dataset = self._dataset_from_spec(transform.dataset) if transform.dataset else self.dataset group = transform.group if transform.group else self.group_dest if dataset.is_dataset_exist(group, self.name): source_dataset = dataset source_group = group source_shape, _ = dataset.get_infos(group, self.name) trailing_transforms = self.transforms[index + 1 :] break stream_cache_attribute = Attribute(self.cache_attributes[0]) if len(self.data_augmentations_list) == 0 and all( self._supports_patch_stream_transform( transform, source_dataset, source_group, stream_cache_attribute, ) for transform in trailing_transforms ): self.cache_attributes[0] = Attribute(stream_cache_attribute) self.cache_attributes_bak[0] = Attribute(stream_cache_attribute) self._patch_stream_source = (source_dataset, source_group, list(source_shape), trailing_transforms) else: self._patch_stream_source = None self._patch_stream_checked = True return self._patch_stream_source
[docs] def can_stream_patch(self, a: int) -> bool: return a == 0 and self._resolve_patch_stream_source() is not None
def _get_streamed_data(self, index: int, a: int, is_input: bool) -> tuple[torch.Tensor, Attribute]: stream_source = self._resolve_patch_stream_source() if stream_source is None: raise RuntimeError("Patch streaming requested on a dataset manager without a streaming source.") source_dataset, source_group, source_shape, transforms = stream_source plan = self.patch.get_read_plan(source_shape, index, a, is_input) data, attributes = source_dataset.read_data_slice(source_group, self.name, plan.data_slices) tensor = self.patch.apply_read_plan(torch.from_numpy(data), plan) cache_attribute = Attribute(self.cache_attributes[a]) cache_attribute.update(attributes) for transform in transforms: tensor = transform(self.name, tensor, cache_attribute) return tensor, cache_attribute
[docs] def unload(self) -> None: self.data.clear() self.augmented_data.clear() self.loaded = False self.augmentationLoaded = self.total_augmentations == 0
[docs] def unload_augmentation(self) -> None: self.augmented_data.clear() self.augmentationLoaded = self.total_augmentations == 0
[docs] def get_data(self, index: int, a: int, patch_transforms: list[Transform], is_input: bool) -> torch.Tensor: if not self.loaded and self.can_stream_patch(a): data, _ = self._get_streamed_data(index, a, is_input) else: data = self.patch.get_data(self._get_tensor(a), index, a, is_input) for transform_function in patch_transforms: data = transform_function(self.name, data, self.cache_attributes[a]) return data
[docs] def get_size(self, a: int) -> int: return self.patch.get_size(a)