# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Tensor and image transforms used in KonfAI preprocessing and postprocessing."""
import os
import tempfile
from abc import ABC, abstractmethod
from multiprocessing import current_process, get_context
from pathlib import Path
from typing import Any
import numpy as np
import SimpleITK as sitk # noqa: N813
import torch
import torch.nn.functional as F # noqa: N812
from konfai import cuda_visible_devices
from konfai.utils.config import apply_config
from konfai.utils.dataset import Attribute, Dataset, data_to_image, image_to_data
from konfai.utils.errors import TransformError
from konfai.utils.ITK import box_with_mask, crop_with_mask
from konfai.utils.runtime import NeedDevice
from konfai.utils.utils import get_module, split_path_spec
[docs]
class Clip(Transform):
    """Clip tensor intensities to a fixed or data-dependent value range.

    Each bound is either a number or a string. ``"min"`` (for ``min_value``) and
    ``"max"`` (for ``max_value``) take the extremum of the tensor, and
    ``"percentile:<float>"`` takes a percentile of it. When ``mask`` is set, the
    data-dependent bounds are computed only over voxels where the mask equals 1.
    """

    def __init__(
        self,
        min_value: float | str = -1024,
        max_value: float | str = 1024,
        save_clip_min: bool = False,
        save_clip_max: bool = False,
        mask: str | None = None,
    ) -> None:
        """Store the clipping configuration.

        Raises:
            ValueError: If both bounds are numeric and ``max_value <= min_value``.
        """
        super().__init__()
        # The defaults are ints, so the check must accept ints as well as floats.
        numeric_bounds = isinstance(min_value, (int, float)) and isinstance(max_value, (int, float))
        if numeric_bounds and max_value <= min_value:
            raise ValueError(
                f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
            )
        self.min_value = min_value
        self.max_value = max_value
        self.save_clip_min = save_clip_min
        self.save_clip_max = save_clip_max
        self.mask = mask

    @staticmethod
    def _resolve_bound(spec: float | str, values: torch.Tensor, extremum: str, label: str) -> Any:
        """Turn one bound specification into a concrete clipping value.

        Args:
            spec: Number, ``extremum`` keyword, or ``"percentile:<float>"``.
            values: Values the data-dependent bounds are computed from.
            extremum: ``"min"`` or ``"max"``; the keyword accepted for this bound.
            label: Parameter name used in error messages.

        Raises:
            ValueError: If the percentile cannot be parsed as a float.
            TypeError: If ``spec`` is an unsupported string.
        """
        if not isinstance(spec, str):
            return spec
        if spec == extremum:
            return torch.min(values) if extremum == "min" else torch.max(values)
        if spec.startswith("percentile:"):
            # Only the parsing is guarded, so errors raised by np.percentile keep their own message.
            try:
                percentile = float(spec.split(":")[1])
            except (IndexError, ValueError) as exc:
                raise ValueError(f"Invalid format for {label}: '{spec}'. Expected 'percentile:<float>'") from exc
            return np.percentile(values, percentile)
        raise TypeError(
            f"Unsupported string for {label}: '{spec}'. Must be a float, '{extremum}', or 'percentile:<float>'."
        )

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Clip ``tensor`` in place and return it.

        When ``save_clip_min`` / ``save_clip_max`` are set, the resolved bounds are
        written to ``cache_attribute`` under ``"Min"`` / ``"Max"``.

        Raises:
            ValueError: If a mask is requested but no dataset provides it for ``name``.
        """
        mask = None
        if self.mask is not None:
            # The first dataset that holds the mask group for this case wins.
            for dataset in self.datasets:
                if dataset.is_dataset_exist(self.mask, name):
                    mask, _ = dataset.read_data(self.mask, name)
                    break
            if mask is None:
                raise ValueError(
                    f"Requested mask '{self.mask}' is not present in any dataset. "
                    "Check your dataset group names or configuration."
                )

        tensor_masked = tensor if mask is None else tensor[mask == 1]

        min_value = self._resolve_bound(self.min_value, tensor_masked, "min", "min_value")
        max_value = self._resolve_bound(self.max_value, tensor_masked, "max", "max_value")

        # In-place assignment: the caller's tensor object is modified and returned.
        tensor[torch.where(tensor.float() < min_value)] = min_value
        tensor[torch.where(tensor.float() > max_value)] = max_value

        if self.save_clip_min:
            cache_attribute["Min"] = min_value
        if self.save_clip_max:
            cache_attribute["Max"] = max_value
        return tensor
[docs]
class Normalize(TransformInverse):
    """Map intensities to a target min/max interval and optionally invert it."""

    def __init__(
        self,
        lazy: bool = False,
        channels: list[int] | None = None,
        min_value: float = -1,
        max_value: float = 1,
        inverse: bool = True,
    ) -> None:
        """Store the target interval.

        Args:
            lazy: When true, only the ``"Min"`` / ``"Max"`` attributes are recorded and
                the tensor is returned unchanged.
            channels: Optional channel indices to normalize; all channels when empty or ``None``.
            min_value: Lower end of the target interval.
            max_value: Upper end of the target interval.
            inverse: Forwarded to the base class constructor.

        Raises:
            ValueError: If ``max_value <= min_value``.
        """
        super().__init__(inverse)
        if max_value <= min_value:
            raise ValueError(
                f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
            )
        self.lazy = lazy
        self.min_value = min_value
        self.max_value = max_value
        self.channels = channels

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Rescale ``tensor`` from its cached ``[Min, Max]`` range to the target interval.

        ``"Min"`` / ``"Max"`` already present in ``cache_attribute`` are reused;
        otherwise they are computed from the tensor (restricted to ``channels`` when set).
        """
        if "Min" not in cache_attribute:
            if self.channels:
                cache_attribute["Min"] = torch.min(tensor[self.channels])
            else:
                cache_attribute["Min"] = torch.min(tensor)
        if "Max" not in cache_attribute:
            if self.channels:
                cache_attribute["Max"] = torch.max(tensor[self.channels])
            else:
                cache_attribute["Max"] = torch.max(tensor)

        if not self.lazy:
            input_min = float(cache_attribute["Min"])
            input_max = float(cache_attribute["Max"])
            norm = input_max - input_min

            if norm == 0:
                # Constant input: avoid dividing by zero and fill with the lower target bound.
                # The message reports the actual input value, not the fill value.
                print(
                    f"[WARNING] Norm is zero for case '{name}': input is constant with value = {input_min}; "
                    f"filling with min_value = {self.min_value}."
                )
                if self.channels:
                    for channel in self.channels:
                        tensor[channel].fill_(self.min_value)
                else:
                    tensor.fill_(self.min_value)
            else:
                if self.channels:
                    # Selected channels are rewritten in place inside the input tensor.
                    for channel in self.channels:
                        tensor[channel] = (self.max_value - self.min_value) * (
                            tensor[channel] - input_min
                        ) / norm + self.min_value
                else:
                    tensor = (self.max_value - self.min_value) * (tensor - input_min) / norm + self.min_value

        return tensor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Map ``tensor`` from the target interval back to the cached ``[Min, Max]`` range.

        In non-lazy mode the ``"Min"`` / ``"Max"`` entries are popped from ``cache_attribute``.
        """
        if self.lazy:
            return tensor
        else:
            input_min = float(cache_attribute.pop("Min"))
            input_max = float(cache_attribute.pop("Max"))
            return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
[docs]
class UnNormalize(Transform):
    """Rescale a tensor with ``(tensor + 1) / 2 * (max_value - min_value) + min_value``."""

    def __init__(self, min_value: int = -1024, max_value: int = 3071) -> None:
        """Store the bounds of the output interval."""
        super().__init__()
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the rescaled tensor; ``name`` and ``cache_attribute`` are not used."""
        span = self.max_value - self.min_value
        unit = (tensor + 1) / 2
        return unit * span + self.min_value
[docs]
class Standardize(TransformInverse):
    """Standardize tensors using cached or computed mean and standard deviation."""

    def __init__(
        self,
        lazy: bool = False,
        mean: list[float] | None = None,
        std: list[float] | None = None,
        mask: str | None = None,
        inverse: bool = True,
    ) -> None:
        """Store the configuration.

        Args:
            lazy: When true, statistics are cached but the tensor is returned unchanged.
            mean: Fixed mean to use instead of the computed one.
            std: Fixed standard deviation to use instead of the computed one.
            mask: Optional dataset group; statistics are computed where the mask equals 1.
            inverse: Forwarded to the base class constructor.
        """
        super().__init__(inverse)
        self.lazy = lazy
        self.mean = mean
        self.std = std
        self.mask = mask

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Cache ``"Mean"`` / ``"Std"`` if missing and return ``(tensor - mean) / std``.

        Raises:
            ValueError: If a mask is requested but no dataset provides it for ``name``.
        """
        mask = None
        if self.mask is not None:
            # Take the mask from the first dataset that has it for this case.
            for dataset in self.datasets:
                if dataset.is_dataset_exist(self.mask, name):
                    mask, _ = dataset.read_data(self.mask, name)
                    break
            if mask is None:
                raise ValueError(
                    f"Requested mask '{self.mask}' is not present in any dataset."
                    " Check your dataset group names or configuration."
                )

        if mask is not None:
            tensor_masked = tensor[mask == 1]
        else:
            tensor_masked = tensor

        # Values already present in the cache take precedence over both the
        # configured and the computed statistics.
        if "Mean" not in cache_attribute:
            if self.mean is None:
                mean_entry = torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
            else:
                mean_entry = torch.tensor([self.mean])
            cache_attribute["Mean"] = mean_entry

        if "Std" not in cache_attribute:
            if self.std is None:
                std_entry = torch.tensor([torch.std(tensor_masked.type(torch.float32))])
            else:
                std_entry = torch.tensor([self.std])
            cache_attribute["Std"] = std_entry

        if self.lazy:
            return tensor

        mean = cache_attribute.get_tensor("Mean")
        std = cache_attribute.get_tensor("Std")
        return (tensor - mean) / std

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor * std + mean``, popping both statistics from the cache (no-op when lazy)."""
        if self.lazy:
            return tensor

        mean = cache_attribute.pop_tensor("Mean")
        std = cache_attribute.pop_tensor("Std")
        return tensor * std + mean
[docs]
class TensorCast(TransformInverse):
    """Cast a tensor to a configured dtype and restore the original dtype on inverse."""

    def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
        """Resolve ``dtype`` (an attribute name of the ``torch`` module) once at construction."""
        super().__init__(inverse)
        self.dtype: torch.dtype = getattr(torch, dtype)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Record the input dtype name under ``"dtype"`` and return the cast tensor."""
        # Stored without the "torch." prefix so safe_dtype_cast can resolve it later.
        cache_attribute["dtype"] = str(tensor.dtype).replace("torch.", "")
        return tensor.type(self.dtype)

    @staticmethod
    def safe_dtype_cast(dtype_str: str) -> torch.dtype:
        """Resolve a dtype name such as ``"float32"`` to the matching ``torch.dtype``.

        Raises:
            ValueError: If ``torch`` has no such attribute, or the attribute is not a dtype.
        """
        try:
            dtype = getattr(torch, dtype_str)
        except AttributeError as exc:
            raise ValueError(f"Unsupported dtype: {dtype_str}") from exc
        # getattr also resolves modules and functions (e.g. "nn"); only real dtypes are valid here.
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Unsupported dtype: {dtype_str}")
        return dtype

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Cast back to the dtype recorded by ``__call__`` (popped from the cache)."""
        return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
[docs]
class Padding(TransformInverse):
    """Pad a tensor with ``torch.nn.functional.pad`` and crop the padding away on inverse."""

    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
        """Store the padding configuration.

        Args:
            padding: Flat ``(before, after)`` pairs in the order expected by ``F.pad``
                (last tensor dimension first).
            mode: Padding mode, optionally followed by ``:<value>`` giving the fill value.
            inverse: Forwarded to the base class constructor.
        """
        super().__init__(inverse)
        # Copy so instances never share the list object used as the default argument.
        self.padding = list(padding)
        self.mode = mode

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Pad ``tensor`` and, when geometry attributes are present, write a shifted ``"Origin"``."""
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            origin = torch.tensor(cache_attribute.get_np_array("Origin"))
            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
            spacing = cache_attribute.get_np_array("Spacing")
            # Shift the origin in the direction-aligned frame, then transform back.
            origin = torch.matmul(origin, matrix)
            for dim in range(len(self.padding) // 2):
                origin[-dim - 1] -= self.padding[dim * 2] * spacing[-dim - 1]
            cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))

        # "mode" or "mode:value"; any other number of ':'-separated parts falls back to value 0.
        mode_parts = self.mode.split(":")
        fill_value = float(mode_parts[1]) if len(mode_parts) == 2 else 0

        # A leading batch axis is added for F.pad and removed again afterwards.
        result = F.pad(
            tensor.unsqueeze(0),
            tuple(self.padding),
            mode_parts[0],
            fill_value,
        ).squeeze(0)
        return result

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
        """Remove the padding added by ``__call__`` and pop the ``"Origin"`` entry it wrote."""
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            cache_attribute.pop("Origin")

        slices = [slice(0, shape) for shape in tensor.shape]
        for dim in range(len(self.padding) // 2):
            # Pair `dim` applies to the dim-th axis counted from the end.
            slices[-dim - 1] = slice(self.padding[dim * 2], tensor.shape[-dim - 1] - self.padding[dim * 2 + 1])
        result = tensor[tuple(slices)]
        return result
[docs]
class Squeeze(TransformInverse):
    """Remove a dimension with ``squeeze`` and put it back with ``unsqueeze`` on inverse."""

    def __init__(self, dim: int, inverse: bool = True) -> None:
        """Store the dimension index that is squeezed / unsqueezed."""
        super().__init__(inverse)
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` with dimension ``self.dim`` squeezed."""
        return torch.squeeze(tensor, self.dim)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
        """Return ``tensor`` with a new dimension inserted at ``self.dim``."""
        return torch.unsqueeze(tensor, self.dim)
[docs]
class Resample(TransformInverse, ABC):
    """Abstract base for resampling transforms built on ``torch.nn.functional.interpolate``."""

    def __init__(self, inverse: bool) -> None:
        """Forward ``inverse`` to the base class constructor."""
        super().__init__(inverse)

    def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
        """Interpolate ``tensor`` to ``size`` and return it on CPU with its original dtype.

        ``uint8`` tensors use nearest-neighbour interpolation; other tensors use
        bilinear (fewer than 4 dimensions) or trilinear interpolation.
        """
        if tensor.dtype == torch.uint8:
            mode = "nearest"
        else:
            mode = "bilinear" if len(tensor.shape) < 4 else "trilinear"

        # interpolate runs in float32 on a batched input; the batch axis is removed afterwards.
        batched = tensor.type(torch.float32).unsqueeze(0)
        resized = F.interpolate(batched, size=tuple(size), mode=mode)
        return resized.squeeze(0).type(tensor.dtype).cpu()

    @abstractmethod
    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor``; implemented by subclasses."""
        pass

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor`` to the second ``"Size"`` value popped from the cache.

        NOTE(review): the first popped ``"Size"`` is discarded and the second is used,
        which presumably relies on ``Attribute`` keeping a history per key - confirm.
        """
        cache_attribute.pop_np_array("Size")
        previous_size = cache_attribute.pop_np_array("Size")
        cache_attribute.pop_np_array("Spacing")
        target = [int(extent) for extent in previous_size]
        return self._resample(tensor, target)
[docs]
class ResampleToResolution(Resample):
    """Resample a tensor so that its spacing matches a requested spacing."""

    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
        """Store the target spacing; negative entries are stored as 0 (keep the image spacing on that axis)."""
        super().__init__(inverse)
        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor`` and update the ``"Spacing"`` and ``"Size"`` attributes.

        ``"Size"`` is written twice: first the input size, then the resampled size.
        """
        image_spacing = cache_attribute.get_tensor("Spacing")
        target_spacing = self.spacing

        # Per-axis ratio target / current; axes with no requested spacing keep a factor of 1.
        resize_factor = torch.tensor(
            [wanted / current if wanted > 0 else 1.0 for wanted, current in zip(target_spacing, image_spacing)]
        )

        cache_attribute["Spacing"] = torch.tensor(
            [float(wanted) if wanted > 0 else float(current) for wanted, current in zip(target_spacing, image_spacing)]
        )

        input_size = torch.tensor(tensor.shape[1:])
        cache_attribute["Size"] = np.asarray([int(extent) for extent in input_size])

        # The factor is flipped before being applied to the tensor's axes.
        scaled = input_size * 1 / resize_factor.flip(0)
        size = [int(extent) for extent in scaled]
        cache_attribute["Size"] = np.asarray(size)

        return self._resample(tensor, size)
[docs]
class ResampleToShape(Resample):
    """Resample a tensor to a requested shape."""

    def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
        """Store the target shape; negative entries are stored as 0 (keep the image size on that axis)."""
        super().__init__(inverse)
        self.shape = torch.tensor([0 if s < 0 else s for s in shape])

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor`` to the target shape and update ``"Spacing"`` / ``"Size"``.

        ``"Size"`` is written twice: first the input shape, then the target shape.
        """
        image_shape = torch.tensor([int(extent) for extent in tensor.shape[1:]])

        # Work on a copy: writing the resolved sizes into self.shape would replace the
        # 0 placeholders with the first image's size for every later call.
        shape = self.shape.clone()
        for i, s in enumerate(self.shape):
            if s == 0:
                shape[i] = image_shape[i]

        if "Spacing" in cache_attribute:
            # Spacing is flipped to the tensor's axis order, scaled, then flipped back.
            cache_attribute["Spacing"] = torch.flip(
                image_shape / shape * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]),
                dims=[0],
            )

        cache_attribute["Size"] = image_shape
        cache_attribute["Size"] = shape
        # _resample expects plain ints, not 0-d tensors.
        return self._resample(tensor, [int(extent) for extent in shape])
[docs]
class Mask(Transform):
    """Overwrite tensor values where a mask equals 0."""

    def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
        """Store the mask source.

        Args:
            path: Either a ``.mha`` file path, or a dataset group name looked up per case.
            value_outside: Value written where the mask equals 0.
        """
        super().__init__()
        self.path = path
        self.value_outside = value_outside

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Set ``tensor`` to ``value_outside`` in place where the mask is 0 and return it.

        Raises:
            NameError: If ``path`` is a dataset group and no dataset provides it for ``name``.
        """
        if self.path.endswith(".mha"):
            # The file is read on every call; a leading axis is added to the array.
            mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
        else:
            mask = None
            for dataset in self.datasets:
                if dataset.is_dataset_exist(self.path, name):
                    mask, _ = dataset.read_data(self.path, name)
                    break
        if mask is None:
            raise NameError(f"Mask : {self.path}/{name} not found")
        # as_tensor accepts both tensors and arrays without the copy-construct
        # warning that torch.tensor() emits when given an existing tensor.
        tensor[torch.as_tensor(mask) == 0] = self.value_outside
        return tensor
[docs]
class Dilate(Transform):
    """Binarize a tensor (``> 0``) and dilate it with a max-pooling window."""

    def __init__(self, dilate: int = 1) -> None:
        """Store the dilation radius.

        Raises:
            ValueError: If ``dilate`` is negative.
        """
        super().__init__()
        if dilate < 0:
            raise ValueError(f"[Dilate] 'dilate' must be >= 0, got {dilate}")
        self.dilate = dilate

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the dilated binary mask in the input dtype (input returned as-is when ``dilate == 0``).

        Raises:
            ValueError: If the tensor does not have 2 or 3 dimensions after the first one.
        """
        if self.dilate == 0:
            return tensor

        binary = (tensor > 0).to(torch.float32)
        # Pick the pooling op from the number of dimensions after the first.
        pool = {2: F.max_pool2d, 3: F.max_pool3d}.get(binary.dim() - 1)
        if pool is None:
            raise ValueError(
                "[Dilate] Unsupported tensor shape for "
                f"'{name}': expected [C,H,W] or [C,D,H,W], got {list(tensor.shape)}"
            )

        # Window of 2 * dilate + 1 with matching padding keeps the spatial size unchanged.
        window = 2 * self.dilate + 1
        dilated = pool(binary, kernel_size=window, stride=1, padding=self.dilate)
        return dilated.to(tensor.dtype)
[docs]
class Sum(Transform):
    """Sum a tensor along one dimension, or merge per-model label maps."""

    def __init__(self, dim: int = 0) -> None:
        """Store the dimension used by the plain sum."""
        super().__init__()
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the summed tensor.

        Without ``"number_of_channels_per_model"`` in the cache this is a plain
        ``torch.sum`` cast back to the input dtype. With it, the entry is popped, the
        non-zero values of each slice after the first are offset by
        ``number_of_channels[i] - 1``, and the slices are accumulated into ``tensor[0]``.
        """
        if "number_of_channels_per_model" not in cache_attribute:
            return torch.sum(tensor, dim=self.dim).to(tensor.dtype)

        number_of_channels = cache_attribute.pop_tensor("number_of_channels_per_model")
        # Both the offsetting and the accumulation below write into the input tensor.
        merged = tensor[0]
        for index, labels in enumerate(tensor[1:]):
            labels[labels != 0] += int(number_of_channels[index]) - 1
            merged += labels
        return merged
[docs]
class Gradient(Transform):
    """Compute forward finite-difference gradients of a 2D or 3D image tensor."""

    def __init__(self, per_dim: bool = False):
        """Store whether per-axis gradients are returned (``True``) or a single combined map."""
        super().__init__()
        self.per_dim = per_dim

    @staticmethod
    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return differences along dims 1 and 2, zero-padded at the end to the input size."""
        diff_rows = image[:, 1:, :] - image[:, :-1, :]
        diff_cols = image[:, :, 1:] - image[:, :, :-1]
        return (
            F.pad(diff_rows, (0, 0, 0, 1), "constant", 0),
            F.pad(diff_cols, (0, 1, 0, 0), "constant", 0),
        )

    @staticmethod
    def _image_gradient_3d(
        image: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return differences along dims 1, 2 and 3, zero-padded at the end to the input size."""
        diff_d1 = image[:, 1:, :, :] - image[:, :-1, :, :]
        diff_d2 = image[:, :, 1:, :] - image[:, :, :-1, :]
        diff_d3 = image[:, :, :, 1:] - image[:, :, :, :-1]
        return (
            F.pad(diff_d1, (0, 0, 0, 0, 0, 1), "constant", 0),
            F.pad(diff_d2, (0, 0, 0, 1, 0, 0), "constant", 0),
            F.pad(diff_d3, (0, 1, 0, 0, 0, 0), "constant", 0),
        )

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the stacked gradients, or their combined map when ``per_dim`` is false.

        The combined map is the norm over dim 0 of ``sigmoid(3 * gradient)``, with a
        leading axis added.
        """
        if len(tensor.shape) == 4:
            components = Gradient._image_gradient_3d(tensor)
        else:
            components = Gradient._image_gradient_2d(tensor)

        result = torch.stack(components, dim=1).squeeze(0)
        if self.per_dim:
            return result

        squashed = torch.sigmoid(result * 3)
        magnitude = squashed.norm(dim=0)
        return torch.unsqueeze(magnitude, 0)
[docs]
class Argmax(Transform):
    """Replace a dimension by the index of its maximum, keeping the dimension with size 1."""

    def __init__(self, dim: int = 0) -> None:
        """Store the dimension reduced by the argmax."""
        super().__init__()
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the argmax along ``self.dim`` with that dimension re-inserted."""
        indices = torch.argmax(tensor, dim=self.dim)
        return indices.unsqueeze(self.dim)
[docs]
class Softmax(Transform):
    """Apply a softmax along one dimension."""

    def __init__(self, dim: int = 0) -> None:
        """Store the dimension the softmax is computed over."""
        super().__init__()
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the softmax of ``tensor`` along ``self.dim``."""
        return tensor.softmax(dim=self.dim)
[docs]
class FlatLabel(Transform):
    """Collapse a label map to a binary map."""

    def __init__(self, labels: list[int] | None = None) -> None:
        """Store the labels mapped to 1; with no labels, every value ``> 0`` maps to 1."""
        super().__init__()
        self.labels = labels

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return a new tensor, shaped and typed like ``tensor``, holding 0s and 1s."""
        data = torch.zeros_like(tensor)
        if not self.labels:
            data[tensor > 0] = 1
            return data
        for label in self.labels:
            data[tensor == label] = 1
        return data
[docs]
class Save(Transform):
    """Pass-through transform that carries a ``dataset`` and optional ``group`` setting.

    ``__call__`` returns its input unchanged; nothing in this class reads the stored settings.
    """

    def __init__(self, dataset: str, group: str | None = None) -> None:
        """Store the dataset specification and the optional group name."""
        super().__init__()
        self.group = group
        self.dataset = dataset

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` unchanged."""
        return tensor
[docs]
class Flatten(Transform):
    """Flatten a tensor to one dimension."""

    def __init__(self) -> None:
        """Initialise the base transform; there is nothing to configure."""
        super().__init__()

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the flattened ``tensor``."""
        return torch.flatten(tensor)
[docs]
class Permute(TransformInverse):
    """Permute the dimensions after the first one; the first dimension stays in place."""

    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
        """Parse ``dims``, a ``|``-separated permutation of the dimensions after the first."""
        super().__init__(inverse)
        # Indices are shifted by one because dimension 0 is always kept first.
        self.dims = [0, *(int(axis) + 1 for axis in dims.split("|"))]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` with its dimensions permuted."""
        order = tuple(self.dims)
        return tensor.permute(order)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Undo the permutation; ``argsort`` of a permutation gives its inverse."""
        restore = tuple(np.argsort(self.dims))
        return tensor.permute(restore)
[docs]
class Flip(TransformInverse):
    """Flip a tensor along the configured dimensions; the inverse applies the same flip."""

    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
        """Parse ``dims``, a ``|``-separated list of dimensions counted after the first one."""
        super().__init__(inverse)
        # str() lets a bare integer from the configuration be parsed too; +1 skips dimension 0.
        self.dims = [int(axis) + 1 for axis in str(dims).split("|")]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` flipped along the configured dimensions."""
        axes = tuple(self.dims)
        return tensor.flip(axes)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Flip again along the same dimensions."""
        axes = tuple(self.dims)
        return tensor.flip(axes)
[docs]
class Canonical(TransformInverse):
    """Resample a 3D volume to a fixed direction matrix ``diag(-1, -1, 1)``."""

    def __init__(self, inverse: bool = True) -> None:
        """Build the target direction matrix in double precision."""
        super().__init__(inverse)
        self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

    @staticmethod
    def _affine_matrix(matrix: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
        """Assemble a 4x4 homogeneous matrix from a 3x3 ``matrix`` and a ``translation`` vector."""
        top = torch.cat((matrix, translation.unsqueeze(0).T), dim=1)
        bottom = torch.tensor([[0, 0, 0, 1]])
        return torch.cat((top, bottom), dim=0)

    @staticmethod
    def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
        """Resample ``data`` through ``matrix`` with ``affine_grid`` / ``grid_sample``.

        ``uint8`` data uses nearest-neighbour sampling, other dtypes use ``"bilinear"``.
        The result is returned in the dtype of ``data``.
        """
        mode = "nearest" if data.dtype == torch.uint8 else "bilinear"
        source = data.unsqueeze(0).type(torch.float32)
        # The last row of the homogeneous matrix is dropped before building the grid.
        grid = F.affine_grid(
            matrix[:, :-1, ...].type(torch.float32),
            [1] + list(data.shape),
            align_corners=True,
        )
        sampled = F.grid_sample(
            source,
            grid,
            align_corners=True,
            mode=mode,
            padding_mode="reflection",
        )
        return sampled.squeeze(0).type(data.dtype)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor`` to the canonical direction and write new ``"Direction"`` / ``"Origin"``."""
        spacing = cache_attribute.get_tensor("Spacing")
        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
        initial_origin = cache_attribute.get_tensor("Origin")

        cache_attribute["Direction"] = (self.canonical_direction).flatten()

        rotation = self.canonical_direction @ initial_matrix.inverse()
        matrix = Canonical._affine_matrix(rotation, torch.tensor([0, 0, 0]))

        # The new origin is chosen so the physical position of the volume centre is unchanged.
        center_voxel = torch.tensor(
            [(tensor.shape[-i - 1] - 1) * spacing[i] / 2 for i in range(3)],
            dtype=torch.double,
        )
        center_physical = initial_matrix @ center_voxel + initial_origin
        cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)

        return Canonical._resample_affine(tensor, matrix.unsqueeze(0))

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Resample ``tensor`` back using the ``"Direction"`` left in the cache after popping.

        NOTE(review): ``"Direction"`` is read right after being popped, which presumably
        relies on ``Attribute`` keeping the earlier value per key - confirm.
        """
        cache_attribute.pop("Direction")
        cache_attribute.pop("Origin")

        previous_direction = cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3)
        forward_rotation = self.canonical_direction @ previous_direction.inverse()
        matrix = Canonical._affine_matrix(forward_rotation.inverse(), torch.tensor([0, 0, 0]))

        return Canonical._resample_affine(tensor, matrix.unsqueeze(0))
class HistogramMatching(Transform):
    """Match the histogram of an image to a reference image of the same case."""

    def __init__(self, reference_group: str) -> None:
        """Store the dataset group the reference image is read from."""
        super().__init__()
        self.reference_group = reference_group

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` after SimpleITK histogram matching against the reference.

        Raises:
            NameError: If no dataset provides the reference group for ``name``.
        """
        image = data_to_image(tensor, cache_attribute)

        # No early exit: when several datasets hold the reference, the last one is used.
        image_ref = None
        for dataset in self.datasets:
            if not dataset.is_dataset_exist(self.reference_group, name):
                continue
            image_ref = dataset.read_image(self.reference_group, name)
        if image_ref is None:
            raise NameError(f"Image : {self.reference_group}/{name} not found")

        matcher = sitk.HistogramMatchingImageFilter()
        matcher.SetNumberOfHistogramLevels(256)
        matcher.SetNumberOfMatchPoints(1)
        matcher.SetThresholdAtMeanIntensity(True)

        matched = matcher.Execute(image, image_ref)
        result, _ = image_to_data(matched)
        return torch.tensor(result)
[docs]
class SelectLabel(Transform):
    """Remap selected labels; every label not listed becomes 0."""

    def __init__(self, labels: list[str]) -> None:
        """Parse mapping strings.

        Each entry has its first and last character stripped and is then split on
        ``,`` into ``old,new`` (e.g. ``"(3,1)"``).
        """
        super().__init__()
        parsed = []
        for label in labels:
            parsed.append(label[1:-1].split(","))
        self.labels = parsed

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return a new tensor where each ``old`` label is replaced by its ``new`` label."""
        data = torch.zeros_like(tensor)
        for pair in self.labels:
            old_label, new_label = pair
            data[tensor == int(old_label)] = int(new_label)
        return data
[docs]
class OneHot(TransformInverse):
    """One-hot encode a label tensor; the inverse takes an argmax over dimension 1."""

    def __init__(self, num_classes: int, inverse: bool = True) -> None:
        """Store the number of classes used for the encoding."""
        super().__init__(inverse)
        self.num_classes = num_classes

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the float one-hot encoding with the class axis moved to position 1, then ``squeeze(0)``."""
        rank = len(tensor.shape)
        encoded = F.one_hot(tensor.type(torch.int64), num_classes=self.num_classes)
        # one_hot appends the class axis last; move it right after dimension 0.
        order = (0, rank, *range(1, rank))
        return encoded.permute(*order).float().squeeze(0)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the argmax over dimension 1, keeping that dimension with size 1."""
        labels = torch.argmax(tensor, dim=1)
        return labels.unsqueeze(1)
[docs]
class KonfAIInference(Transform):
    """Run a nested KonfAI app inference on the tensor in a spawned subprocess."""

    # Class-level flag set to False; __call__ also refuses to run in daemon processes.
    supports_dataloader_workers = False

    def __init__(
        self,
        repo_id: str = "VBoussot/MRSegmentator-KonfAI",
        model_name: str = "MRSegmentator",
        checkpoints_name: list[str] = ["fold_0"],
        number_of_tta: int = 0,
        number_of_mc: int = 0,
        per_channel: bool = False,
    ):
        """Store the app identifier and inference options.

        Args:
            repo_id: First part of the ``"<repo_id>:<model_name>"`` app identifier.
            model_name: Second part of the app identifier.
            checkpoints_name: Checkpoint names forwarded to the app.
            number_of_tta: Forwarded to the app's ``infer`` call.
            number_of_mc: Stored on the instance (see the note in ``infer_entry``).
            per_channel: When true, each channel is written as its own case.
        """
        super().__init__()
        self.repo_id = repo_id
        self.model_name = model_name
        # Copy so instances never share the list object used as the default argument.
        self.checkpoints_name = list(checkpoints_name)
        self.number_of_tta = number_of_tta
        self.number_of_mc = number_of_mc
        self.per_channel = per_channel

    def infer_entry(self, dataset_path: Path, output_path: Path, gpu: list[int]):
        """Subprocess entry point: run the KonfAI app on ``dataset_path`` into ``output_path``.

        Raises:
            RuntimeError: If the optional ``konfai_apps`` package cannot be imported.
        """
        try:
            from konfai_apps import KonfAIApp
        except ImportError as exc:  # pragma: no cover - depends on optional install
            raise RuntimeError(
                "KonfAIInference requires the standalone 'konfai-apps' package. "
                "Install it from the repository with 'pip install -e ./konfai-apps'."
            ) from exc

        # Nested KonfAI runs must choose their own rendezvous ports instead of
        # inheriting the parent's already-bound distributed settings.
        os.environ.pop("KONFAI_MASTER_PORT", None)
        os.environ.pop("KONFAI_TENSORBOARD_PORT", None)
        konfai_app = KonfAIApp(f"{self.repo_id}:{self.model_name}", False, False)
        # NOTE(review): `mc` is hard-coded to 0, so `self.number_of_mc` is never
        # forwarded - confirm whether that is intentional.
        konfai_app.infer(
            [[dataset_path]],
            output_path,
            0,
            self.checkpoints_name,
            self.number_of_tta,
            mc=0,
            uncertainty=False,
            gpu=gpu,
        )

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Write the tensor to a temporary dataset, run inference, and return the stacked outputs.

        Raises:
            RuntimeError: If called from a daemon process, if the subprocess exits with a
                non-zero code, or if it produced no output image.
        """
        if current_process().daemon:
            raise RuntimeError(
                "KonfAIInference cannot run inside daemon DataLoader workers. "
                "Use 'Dataset.num_workers: 0' for pipelines that include this transform."
            )
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_path = Path(tmpdir) / "Dataset"
            output_path = Path(tmpdir) / "Output"

            if self.per_channel:
                # One case directory (P000, P001, ...) per channel.
                for i, channel in enumerate(tensor):
                    image = data_to_image(channel.unsqueeze(0).numpy(), cache_attribute)
                    case_dir = dataset_path / f"P{i:03d}"
                    case_dir.mkdir(parents=True, exist_ok=True)
                    sitk.WriteImage(image, str(case_dir / "Volume.mha"))
            else:
                image = data_to_image(tensor.numpy(), cache_attribute)
                case_dir = dataset_path / "P000"
                case_dir.mkdir(parents=True, exist_ok=True)
                sitk.WriteImage(image, str(case_dir / "Volume.mha"))

            ctx = get_context("spawn")
            p = ctx.Process(target=self.infer_entry, args=(dataset_path, output_path, cuda_visible_devices()))
            p.start()
            p.join()
            if p.exitcode != 0:
                raise RuntimeError("Inference process failed")

            # rglob yields files in arbitrary order; sort so the stacking order is reproducible.
            output_files = sorted(file for file in output_path.rglob("*.mha") if file.name != "InferenceStack.mha")
            if not output_files:
                raise RuntimeError("Inference process produced no output image")

            result = [torch.from_numpy(image_to_data(sitk.ReadImage(str(file)))[0]) for file in output_files]
            return torch.stack(result, dim=1).squeeze(0)
[docs]
class InferenceStack(Transform):
    """Save a stack of predictions and reduce it to a single tensor (mean or median)."""

    def __init__(self, dataset: str, name: str, mode: str = "mean"):
        """Store the output dataset and the reduction mode.

        Args:
            dataset: Path specification of the dataset the stack is written to; when
                empty, the last entry of ``self.datasets`` is used at call time.
            name: Stored on the instance; not read by ``__call__``.
            mode: ``"median"`` reduces with the median, anything else with the mean.
                ``"Seg"`` additionally stores the stack as per-prediction argmax labels.
        """
        # Every other transform in this module initialises its base class; do the same
        # here so the base-class state is set up before __call__ reads self.datasets.
        super().__init__()
        self.dataset = None
        if dataset:
            filename, _, file_format = split_path_spec(dataset)
            self.dataset = Dataset(filename, file_format)
        self.name = name
        self.mode = mode

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Write the stack under the ``"InferenceStack"`` group and return its reduction over dim 0.

        A stack holding a single prediction is returned directly (``squeeze(0)``) and
        nothing is written.
        """
        if tensors.shape[0] == 1:
            return tensors.squeeze(0)

        if self.mode == "Seg":
            _tensors = torch.argmax(torch.softmax(tensors, dim=1), dim=1).to(torch.uint8)
        else:
            _tensors = tensors.squeeze(1)

        dataset = self.dataset if self.dataset else self.datasets[-1]
        dataset.write("InferenceStack", name, _tensors.float().cpu().numpy(), cache_attribute)

        # The reduction runs in float and is cast back to the input dtype.
        return (
            torch.median(tensors.float(), dim=0).values.to(tensors.dtype)
            if self.mode == "median"
            else tensors.float().mean(0).to(tensors.dtype)
        )
[docs]
class Variance(Transform):
    """Variance over the first dimension of a stack of tensors."""

    def __init__(self) -> None:
        """Initialise the base transform; there is nothing to configure."""
        super().__init__()

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the variance over dim 0 with a leading axis, or zeros shaped like ``tensors[0]`` for one entry."""
        if tensors.shape[0] > 1:
            return tensors.float().var(0).unsqueeze(0)
        return torch.zeros_like(tensors[0])
[docs]
class SegmentationDisagreement(Transform):
    """Per-voxel disagreement between several label maps stacked along dimension 0."""

    def __init__(self, ignore_background: bool = False) -> None:
        """Store whether entries labelled 0 are excluded from the vote."""
        super().__init__()
        self.ignore_background = ignore_background

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``1 - majority_count / valid_count`` per voxel, with a leading axis.

        Voxels with no valid vote, and stacks holding at most one segmentation, yield 0.
        """
        # tensors shape: [N, ...] with N segmentations and integer labels per voxel
        if tensors.shape[0] <= 1:
            return torch.zeros_like(tensors[0], dtype=torch.float32).unsqueeze(0)

        tensors = tensors.long()
        valid = tensors != 0 if self.ignore_background else torch.ones_like(tensors, dtype=torch.bool)

        # Count, for every label present, how many valid segmentations voted for it.
        votes = torch.stack(
            [((tensors == label) & valid).sum(dim=0) for label in torch.unique(tensors)],
            dim=0,
        )  # [L, ...]
        majority = votes.max(dim=0).values
        voters = valid.sum(dim=0)

        disagreement = torch.zeros_like(tensors[0], dtype=torch.float32)
        has_votes = voters > 0
        disagreement[has_votes] = 1.0 - (majority[has_votes].float() / voters[has_votes].float())
        return disagreement.unsqueeze(0)
[docs]
class Percentage(Transform):
    """Express a tensor as a percentage of a fixed baseline."""

    def __init__(self, baseline: float) -> None:
        """Store the baseline that corresponds to 100%."""
        super().__init__()
        self.baseline = baseline

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensors / baseline * 100``."""
        ratio = tensors / self.baseline
        return ratio * 100.0
[docs]
class StandardDeviation(Transform):
    """Standard deviation over the first dimension of a stack of tensors."""

    def __init__(self) -> None:
        """Initialise the base transform; there is nothing to configure."""
        super().__init__()

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return the std over dim 0 with a leading axis, or zeros shaped like ``tensors[0]`` for one entry."""
        if tensors.shape[0] > 1:
            return tensors.float().std(0).unsqueeze(0)
        return torch.zeros_like(tensors[0])
[docs]
class Statistics(Transform):
    """Record global intensity statistics of a tensor in the attribute cache."""

    def __init__(self) -> None:
        """Initialise the base transform; there is nothing to configure."""
        super().__init__()

    def __call__(self, name: str, tensors: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Write ``ImageMin`` / ``ImageMax`` / ``ImageMean`` / ``ImageStd`` and return ``tensors`` unchanged."""
        values = tensors.float()
        cache_attribute["ImageMin"] = values.min()
        cache_attribute["ImageMax"] = values.max()
        cache_attribute["ImageMean"] = values.mean()
        cache_attribute["ImageStd"] = values.std()
        return tensors
[docs]
class Crop(TransformInverse):
    """Crop a tensor using the ``"box"`` attribute and pad it back on inverse."""

    def __init__(self, inverse: bool = True) -> None:
        """Forward ``inverse`` to the base class constructor."""
        super().__init__(inverse)

    @staticmethod
    def _parse_box(box_str: str) -> np.ndarray:
        """Parse a bracketed, whitespace-separated integer string into an ``(n, 2)`` int64 array."""
        flat = np.fromstring(box_str.replace("[", " ").replace("]", " "), sep=" ", dtype=np.int64)
        return flat.reshape(-1, 2)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Crop ``tensor`` with ``crop_with_mask``; returned unchanged when no ``"box"`` is cached.

        When Origin, Spacing and Direction are all present, a shifted ``"Origin"`` is
        written to the cache before cropping.
        """
        if "box" not in cache_attribute:
            return tensor

        box = self._parse_box(cache_attribute["box"])
        # The second value of each row is replaced by (axis size - value) before cropping.
        for i, ((_, b), s) in enumerate(zip(box, tensor.shape[1:])):
            box[i][1] = s - b

        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            origin = torch.tensor(cache_attribute.get_np_array("Origin"))
            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
            spacing = cache_attribute.get_np_array("Spacing")
            # Shift the origin in the direction-aligned frame, then transform back.
            origin = torch.matmul(origin, matrix)
            for dim in range(box.shape[0]):
                origin[-dim - 1] += box[dim][0] * spacing[-dim - 1]
            cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))

        image = data_to_image(tensor.numpy(), cache_attribute)
        result = crop_with_mask(image, box)
        data, _ = image_to_data(result)
        return torch.from_numpy(data)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Pad ``tensor`` back (replicate mode) by the amounts in ``"box"`` and pop that entry."""
        if "box" not in cache_attribute:
            return tensor

        box = self._parse_box(cache_attribute.pop("box"))

        # __call__ only writes "Origin" when all three geometry attributes exist, so the
        # matching pop is guarded the same way.
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            cache_attribute.pop_np_array("Origin")

        # F.pad expects pairs starting from the last dimension, hence the reversed rows.
        padding = []
        for b in reversed(box):
            padding.extend([b[0], b[1]])
        result = F.pad(tensor.unsqueeze(0), tuple(padding), "replicate").squeeze(0)
        return result