# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Compatibility facade for KonfAI utility helpers and lightweight array utilities."""
import importlib
import itertools
import os
import re
from types import ModuleType
import numpy as np
from konfai.utils.errors import DatasetManagerError
[docs]
def get_module(classpath: str, default_classpath: str) -> tuple[ModuleType, str]:
if len(classpath.split(":")) > 1:
module_name = ".".join(classpath.split(":")[:-1])
name = classpath.split(":")[-1]
else:
module_name = (
default_classpath + ("." if len(classpath.split(".")) > 2 else "") + ".".join(classpath.split(".")[:-1])
)
name = classpath.split(".")[-1]
previous_mode = os.environ.get("KONFAI_CONFIG_MODE")
os.environ["KONFAI_CONFIG_MODE"] = "Import"
try:
module = importlib.import_module(module_name)
finally:
if previous_mode is None:
os.environ.pop("KONFAI_CONFIG_MODE", None)
else:
os.environ["KONFAI_CONFIG_MODE"] = previous_mode
return module, name.split("/")[0]
[docs]
def get_patch_slices_from_nb_patch_per_dim(
patch_size_tmp: list[int],
nb_patch_per_dim: list[tuple[int, bool]],
overlap: int | None,
) -> list[tuple[slice, ...]]:
patch_slices = []
slices: list[list[slice]] = []
if overlap is None:
overlap = 0
patch_size = []
i = 0
for nb in nb_patch_per_dim:
if nb[1]:
patch_size.append(1)
else:
patch_size.append(patch_size_tmp[i])
i += 1
for dim, nb in enumerate(nb_patch_per_dim):
slices.append([])
for index in range(nb[0]):
start = (patch_size[dim] - overlap) * index
end = start + patch_size[dim]
slices[dim].append(slice(start, end))
for chunk in itertools.product(*slices):
patch_slices.append(tuple(chunk))
return patch_slices
[docs]
def get_patch_slices_from_shape(
patch_size: list[int], shape: list[int], overlap_tmp: int | None
) -> tuple[list[tuple[slice, ...]], list[tuple[int, bool]]]:
if patch_size is None or all(p == 0 for p in patch_size):
patch_size = shape
if len(shape) != len(patch_size):
raise DatasetManagerError(
f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
f"patch_size: {patch_size}",
f"shape: {shape}",
"Both must have the same number of dimensions (e.g., 3D patch for 3D volume).",
)
patch_slices = []
nb_patch_per_dim = []
slices: list[list[slice]] = []
if overlap_tmp is None:
size = [np.ceil(a / b) for a, b in zip(shape, patch_size)]
tmp = np.zeros(len(size), dtype=np.int_)
for i, s in enumerate(size):
if s > 1:
tmp[i] = np.mod(patch_size[i] - np.mod(shape[i], patch_size[i]), patch_size[i]) // (size[i] - 1)
overlap = tmp
else:
overlap = [overlap_tmp if size > 1 else 0 for size in patch_size]
for dim in range(len(shape)):
if overlap[dim] >= patch_size[dim]:
raise ValueError(
f"Overlap must be less than patch size, got overlap={overlap[dim]}",
f" ≥ patch_size={patch_size[dim]} at dim={dim}",
)
for dim in range(len(shape)):
slices.append([])
index = 0
while True:
start = (patch_size[dim] - overlap[dim]) * index
end = start + patch_size[dim]
if end >= shape[dim]:
end = shape[dim]
slices[dim].append(slice(start, end))
break
slices[dim].append(slice(start, end))
index += 1
nb_patch_per_dim.append((index + 1, patch_size[dim] == 1))
for chunk in itertools.product(*slices):
patch_slices.append(tuple(chunk))
return patch_slices, nb_patch_per_dim
SUPPORTED_EXTENSIONS = [
"mha",
"mhd", # MetaImage
"nii",
"nii.gz", # NIfTI
"nrrd",
"nrrd.gz", # NRRD
"gipl",
"gipl.gz", # GIPL
"hdr",
"img", # Analyze
"dcm", # DICOM (si GDCM activé)
"tif",
"tiff", # TIFF
"png",
"jpg",
"jpeg",
"bmp", # 2D formats
"h5",
"itk.txt",
"fcsv",
"xml",
"vtk",
"npy",
]
_WINDOWS_ABSOLUTE_PATH_RE = re.compile(r"^[A-Za-z]:[\\/]")
[docs]
def is_windows_absolute_path(path: str) -> bool:
"""Return whether *path* looks like a Windows absolute path."""
return bool(_WINDOWS_ABSOLUTE_PATH_RE.match(path))
[docs]
def split_path_spec(
value: str,
*,
default_format: str = "mha",
allowed_flags: set[str] | None = None,
supported_extensions: list[str] | None = None,
) -> tuple[str, str | None, str]:
"""Split a KonfAI ``path[:flag]:format`` spec without breaking Windows paths.
KonfAI accepts dataset-like strings such as:
- ``./Dataset``
- ``./Dataset:mha``
- ``./Dataset:a:mha``
- ``C:\\Data\\Dataset:mha``
- ``C:\\Data\\Dataset:a:mha``
Parsing is performed from the right so the drive separator in Windows paths
is preserved.
"""
extensions = SUPPORTED_EXTENSIONS if supported_extensions is None else supported_extensions
parts = value.rsplit(":", 2)
if len(parts) == 1:
return value, None, default_format
if len(parts) == 2:
path, maybe_format = parts
if maybe_format in extensions:
return path, None, maybe_format
if is_windows_absolute_path(value):
return value, None, default_format
return path, None, maybe_format
path, middle, file_format = parts
if file_format in extensions:
if allowed_flags is not None and middle in allowed_flags:
return path, middle, file_format
return f"{path}:{middle}", None, file_format
return path, middle, file_format