Source code for konfai.utils.dataset

# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Dataset file abstractions and image conversion utilities for KonfAI."""

import ast
import copy
import csv
import glob
import math
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import h5py
import numpy as np
import SimpleITK as sitk  # noqa: N813
import torch
from lxml import etree  # nosec B410

from konfai import current_date
from konfai.utils.utils import SUPPORTED_EXTENSIONS


[docs] class Attribute(dict[str, Any]): """Metadata container storing repeated values with a stack-like naming scheme.""" def __init__(self, attributes: dict[str, Any] | None = None) -> None: super().__init__() attributes = attributes or {} for k, v in attributes.items(): super().__setitem__(copy.deepcopy(k), copy.deepcopy(v)) def __getitem__(self, key: str) -> Any: i = len([k for k in super().keys() if k.startswith(key)]) if i > 0 and f"{key}_{i - 1}" in super().keys(): return str(super().__getitem__(f"{key}_{i - 1}")) else: raise NameError(f"{key} not in cache_attribute") def __setitem__(self, key: str, value: Any) -> None: if "_" not in key: i = len([k for k in super().keys() if k.startswith(key)]) result = None if isinstance(value, torch.Tensor): result = str(value.numpy()) else: result = str(value) result = result.replace("\n", "") super().__setitem__(f"{key}_{i}", result) else: result = None if isinstance(value, torch.Tensor): result = str(value.numpy()) else: result = str(value) result = result.replace("\n", "") super().__setitem__(key, result)
[docs] def pop(self, key: str, default: Any = None) -> Any: i = len([k for k in super().keys() if k.startswith(key)]) if i > 0 and f"{key}_{i - 1}" in super().keys(): return super().pop(f"{key}_{i - 1}") else: raise NameError(f"{key} not in cache_attribute")
[docs] def get_np_array(self, key) -> np.ndarray: return np.fromstring(self[key][1:-1], sep=" ", dtype=np.double)
[docs] def get_tensor(self, key) -> torch.Tensor: return torch.tensor(self.get_np_array(key)).to(torch.float32)
[docs] def pop_np_array(self, key): return np.fromstring(self.pop(key)[1:-1], sep=" ", dtype=np.double)
[docs] def pop_tensor(self, key) -> torch.Tensor: return torch.tensor(self.pop_np_array(key))
def __contains__(self, key: object) -> bool: if not isinstance(key, str): return False return any(k.startswith(key) for k in super().keys())
[docs] def is_info(self, key: str, value: str) -> bool: return key in self and self[key] == value
def _update_running_statistics( state: dict[str, float] | None, array: np.ndarray, ) -> dict[str, float]: """Update running min/max/mean/std statistics from a NumPy chunk.""" values = np.asarray(array, dtype=np.float64).reshape(-1) if values.size == 0: return state or {"count": 0.0, "mean": 0.0, "m2": 0.0, "min": np.inf, "max": -np.inf} if state is None: state = {"count": 0.0, "mean": 0.0, "m2": 0.0, "min": np.inf, "max": -np.inf} chunk_count = float(values.size) chunk_mean = float(values.mean()) chunk_m2 = float(np.square(values - chunk_mean).sum()) total_count = state["count"] + chunk_count delta = chunk_mean - state["mean"] if total_count > 0: state["mean"] += delta * chunk_count / total_count state["m2"] += chunk_m2 + delta * delta * state["count"] * chunk_count / total_count state["count"] = total_count state["min"] = min(state["min"], float(values.min())) state["max"] = max(state["max"], float(values.max())) return state def _finalize_running_statistics(state: dict[str, float] | None) -> dict[str, float]: """Convert a running-statistics state into the public stats dictionary.""" if state is None or state["count"] == 0: return {"min": 0.0, "max": 0.0, "mean": 0.0, "std": 0.0} variance = state["m2"] / (state["count"] - 1) if state["count"] > 1 else 0.0 return { "min": state["min"], "max": state["max"], "mean": state["mean"], "std": math.sqrt(max(variance, 0.0)), }
def is_an_image(attributes: Attribute):
    """Tell whether ``attributes`` carries the three geometry entries of an image.

    The entries checked, in order, are ``Origin``, ``Spacing`` and ``Direction``.
    """
    required = ("Origin", "Spacing", "Direction")
    return all(entry in attributes for entry in required)
def data_to_image(data: np.ndarray, attributes: Attribute) -> sitk.Image:
    """Build a SimpleITK image from a channel-first array and its KonfAI attributes.

    A single-channel array becomes a scalar image; otherwise the channel axis
    is moved last and the array is converted as a vector image. Every
    non-empty attribute is copied to the image metadata, then origin, spacing
    and direction are applied.

    Raises:
        NameError: if ``attributes`` lacks the image geometry entries.
    """
    if not is_an_image(attributes):
        raise NameError("Data is not an image")

    single_channel = data.shape[0] == 1
    if single_channel:
        image = sitk.GetImageFromArray(data[0])
    else:
        # Channel-first -> channel-last, the layout expected for vector images.
        image = sitk.GetImageFromArray(np.moveaxis(data, 0, -1), isVector=True)

    for key, value in attributes.items():
        if not value or not len(value):
            continue
        image.SetMetaData(key, value)

    origin = attributes.get_np_array("Origin").tolist()
    spacing = attributes.get_np_array("Spacing").tolist()
    direction = attributes.get_np_array("Direction").tolist()
    image.SetOrigin(origin)
    image.SetSpacing(spacing)
    image.SetDirection(direction)
    return image
def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
    """Split a SimpleITK image into a channel-first array and its attributes.

    The attributes receive ``Origin``, ``Spacing`` and ``Direction`` followed by
    every metadata entry of the image. Scalar images gain a leading channel
    axis of size one; vector images have their component axis moved first.
    """
    attributes = Attribute()
    geometry = (
        ("Origin", image.GetOrigin()),
        ("Spacing", image.GetSpacing()),
        ("Direction", image.GetDirection()),
    )
    for entry, value in geometry:
        attributes[entry] = np.asarray(value)
    for key in image.GetMetaDataKeys():
        attributes[key] = image.GetMetaData(key)

    array = sitk.GetArrayFromImage(image)
    if image.GetNumberOfComponentsPerPixel() == 1:
        return np.expand_dims(array, 0), attributes
    return np.moveaxis(array, -1, 0), attributes
def get_infos(filename: str | Path) -> tuple[list[int], Attribute]:
    """Read shape and metadata from an image file without loading its full pixel data.

    Returns:
        ``(shape, attributes)`` where ``shape`` is ``[components, ...]`` followed
        by the image size in reversed order, and ``attributes`` holds
        ``Origin``, ``Spacing``, ``Direction`` and the file metadata.
    """
    reader = sitk.ImageFileReader()
    reader.SetFileName(str(filename))
    reader.ReadImageInformation()

    attributes = Attribute()
    attributes["Origin"] = np.asarray(reader.GetOrigin())
    attributes["Spacing"] = np.asarray(reader.GetSpacing())
    attributes["Direction"] = np.asarray(reader.GetDirection())
    for key in reader.GetMetaDataKeys():
        attributes[key] = reader.GetMetaData(key)

    # The size is reversed for every dimensionality (previously only for 3-D)
    # so the reported shape follows the same axis order as the arrays returned
    # by ``image_to_data`` and as the shape built in ``_file_to_image_slice``.
    shape = [reader.GetNumberOfComponents(), *reversed(reader.GetSize())]
    return shape, attributes
[docs] def read_landmarks(filename: Path) -> np.ndarray | None: """Read Slicer-style fiducial landmarks from disk.""" data = None with open(filename, newline="") as csvfile: reader = csv.reader(filter(lambda row: row[0] != "#", csvfile)) lines = list(reader) data = np.zeros((len(list(lines)), 3), dtype=np.double) for i, row in enumerate(lines): data[i] = np.array(row[1:4], dtype=np.double) csvfile.close() return data
def write_landmarks(data: np.ndarray, filename: Path) -> None:
    """Write landmarks to the Slicer Markups fiducial CSV-like format.

    A three-line ``#`` header is written first, then one row per landmark
    holding a 1-based node id, the three coordinates taken from columns 0-2 of
    ``data``, fixed orientation/visibility flags and a ``F-<n>`` label.
    """
    header = (
        "# Markups fiducial file version = 4.6\n"
        "# CoordinateSystem = LPS\n"
        "# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n"
    )
    with open(filename, "w") as stream:
        stream.write(header)
        for index in range(data.shape[0]):
            number = str(index + 1)
            fields = [
                "vtkMRMLMarkupsFiducialNode_" + number,
                str(data[index, 0]),
                str(data[index, 1]),
                str(data[index, 2]),
                "0,0,0,1,1,1,0",
                "F-" + number,
                "",
                "vtkMRMLScalarVolumeNode1",
            ]
            stream.write(",".join(fields) + "\n")
[docs] class Dataset: """Filesystem or HDF5-backed dataset abstraction used across KonfAI."""
class AbstractFile(ABC):
    """Interface shared by the storage back-ends of a dataset.

    Implementations are used as context managers and expose readers, a
    writer and introspection helpers addressed by ``group`` and ``name``.
    """

    @abstractmethod
    def __init__(self) -> None:
        """Set up the back-end."""

    @abstractmethod
    def __enter__(self):
        """Enter the context of the back-end."""

    @abstractmethod
    def __exit__(self, exc_type, value, traceback):
        """Leave the context of the back-end."""

    @abstractmethod
    def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
        """Read a whole entry as ``(array, attributes)``."""

    @abstractmethod
    def file_to_data_slice(self, group: str, name: str, slices: tuple[slice, ...]) -> tuple[np.ndarray, Attribute]:
        """Read the part of an entry selected by ``slices`` as ``(array, attributes)``."""

    @abstractmethod
    def file_to_data_statistics(
        self,
        group: str,
        name: str,
        channels: list[int] | None = None,
    ) -> dict[str, float]:
        """Return summary statistics of an entry, optionally limited to ``channels``."""

    @abstractmethod
    def data_to_file(
        self,
        name: str,
        data: sitk.Image | sitk.Transform | np.ndarray,
        attributes: Attribute | None = None,
    ) -> None:
        """Store ``data`` under ``name`` together with its optional attributes."""

    @abstractmethod
    def get_names(self, group: str) -> list[str]:
        """List the entry names available for ``group``."""

    @abstractmethod
    def get_group(self) -> list[str]:
        """List the groups available in the file."""

    @abstractmethod
    def is_exist(self, group: str, name: str | None = None) -> bool:
        """Tell whether ``group`` (and optionally ``name``) exists."""

    @abstractmethod
    def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
        """Return the shape and attributes of an entry."""
class H5File(AbstractFile):
    """Dataset back-end stored in a single HDF5 container.

    Entries are HDF5 datasets, optionally nested in HDF5 groups; their HDF5
    attributes hold the stringified metadata.
    """

    def __init__(self, filename: str, read: bool) -> None:
        """Remember the target path (``.h5`` is appended when missing) and the open mode."""
        self.h5: h5py.File | None = None
        self.filename = filename if filename.endswith(".h5") else f"{filename}.h5"
        self.read = read

    def __enter__(self):
        """Open the container and return the underlying ``h5py.File``.

        In read mode the file is opened read-only. Otherwise it is created
        (with its parent directory) when missing, or opened for update, and
        its ``Date`` attribute is refreshed.
        """
        if self.read:
            self.h5 = h5py.File(self.filename, "r")
        else:
            if os.path.exists(self.filename):
                self.h5 = h5py.File(self.filename, "r+")
            else:
                # ``dirname`` copes with any path separator and with files
                # located directly under the filesystem root.
                parent = os.path.dirname(self.filename)
                if parent:
                    os.makedirs(parent, exist_ok=True)
                self.h5 = h5py.File(self.filename, "w")
            self.h5.attrs["Date"] = current_date()
        self.h5.__enter__()
        return self.h5

    def __exit__(self, exc_type, value, traceback):
        """Close the container if it was opened."""
        if self.h5 is not None:
            self.h5.close()

    def _require_dataset(self, groups: str, name: str) -> h5py.Dataset:
        """Return the dataset addressed by ``groups``/``name`` or raise ``NameError``."""
        dataset = self._get_dataset(groups, name)
        if dataset is None:
            raise NameError(f"Dataset '{groups}/{name}' not found in '{self.filename}'.")
        return dataset

    @staticmethod
    def _transform_type(transform: sitk.Transform) -> str:
        """Return the stored type tag of a supported transform.

        Raises:
            NameError: if the transform is not a B-spline, affine or Euler 3-D transform.
        """
        supported = (
            (sitk.BSplineTransform, "BSplineTransform_double_3_3"),
            (sitk.AffineTransform, "AffineTransform_double_3_3"),
            (sitk.Euler3DTransform, "Euler3DTransform_double_3_3"),
        )
        for transform_class, type_name in supported:
            if isinstance(transform, transform_class):
                return type_name
        raise NameError(f"Unsupported transform type '{type(transform).__name__}'.")

    def file_to_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
        """Read a whole dataset and its attributes.

        Raises:
            NameError: if the dataset does not exist.
        """
        dataset = self._require_dataset(groups, name)
        data = np.zeros(dataset.shape, dataset.dtype)
        dataset.read_direct(data)
        return data, Attribute({k: str(v) for k, v in dataset.attrs.items()})

    def file_to_data_slice(self, groups: str, name: str, slices: tuple[slice, ...]) -> tuple[np.ndarray, Attribute]:
        """Read the part of a dataset selected by ``slices`` and its attributes.

        Raises:
            NameError: if the dataset does not exist.
        """
        dataset = self._require_dataset(groups, name)
        data = np.asarray(dataset[slices])
        return data, Attribute({k: str(v) for k, v in dataset.attrs.items()})

    def file_to_data_statistics(
        self,
        groups: str,
        name: str,
        channels: list[int] | None = None,
    ) -> dict[str, float]:
        """Compute min/max/mean/std of a dataset by streaming it in chunks.

        Chunks are taken along axis 1 (axis 0 for 1-D datasets) and sized so
        that ``chunk_length * trailing_size`` stays around eight million
        elements. ``channels`` indexes the first axis of each chunk.

        Raises:
            NameError: if the dataset does not exist.
        """
        dataset = self._require_dataset(groups, name)
        axis = 1 if dataset.ndim > 1 else 0
        if axis + 1 < dataset.ndim:
            trailing_size = int(np.prod(dataset.shape[axis + 1 :], dtype=np.int64))
        else:
            trailing_size = 1
        max_elements = 8_000_000
        chunk_length = max(1, max_elements // max(1, trailing_size))

        state: dict[str, float] | None = None
        for start in range(0, dataset.shape[axis], chunk_length):
            selection = [slice(None)] * dataset.ndim
            selection[axis] = slice(start, min(dataset.shape[axis], start + chunk_length))
            chunk = np.asarray(dataset[tuple(selection)])
            if channels is not None:
                chunk = chunk[channels]
            state = _update_running_statistics(state, chunk)
        return _finalize_running_statistics(state)

    def data_to_file(
        self,
        name: str,
        data: sitk.Image | sitk.Transform | np.ndarray,
        attributes: Attribute | None = None,
    ) -> None:
        """Store an image, a transform or an array under ``name``.

        ``name`` may contain ``/`` separated HDF5 groups, created on demand.
        An existing dataset with the same name is replaced. Images contribute
        their own attributes; transforms are stored as one parameter row per
        transform with ``<i>:Transform`` / ``<i>:FixedParameters`` attributes.
        Nothing is written when the container is not open.

        Raises:
            NameError: if a transform of an unsupported type is given.
        """
        if self.h5 is None:
            return
        if attributes is None:
            attributes = Attribute()

        if isinstance(data, sitk.Image):
            data, image_attributes = image_to_data(data)
            attributes.update(image_attributes)
        elif isinstance(data, sitk.Transform):
            if isinstance(data, sitk.CompositeTransform):
                transforms = [data.GetNthTransform(i) for i in range(data.GetNumberOfTransforms())]
            else:
                transforms = [data]
            parameters = []
            for i, transform in enumerate(transforms):
                attributes[f"{i}:Transform"] = self._transform_type(transform)
                attributes[f"{i}:FixedParameters"] = transform.GetFixedParameters()
                parameters.append(np.asarray(transform.GetParameters()))
            # Rows of different lengths cannot form a rectangular array: pad
            # the shorter ones with NaN (a no-op when all lengths are equal).
            longest = max((len(row) for row in parameters), default=0)
            data = np.asarray(
                [np.pad(row, (0, longest - len(row)), constant_values=np.nan) for row in parameters]
            )

        h5_group = self.h5
        group, separator, leaf = name.rpartition("/")
        if separator:
            if group not in self.h5:
                self.h5.create_group(group)
            h5_group = self.h5[group]
            name = leaf

        if name in h5_group:
            del h5_group[name]
        dataset = h5_group.create_dataset(name, data=data, dtype=data.dtype, chunks=None)
        dataset.attrs.update({k: str(v) for k, v in attributes.items()})

    def is_exist(self, group: str, name: str | None = None) -> bool:
        """Tell whether ``group`` is a dataset, or a group containing ``name``."""
        if self.h5 is None or group not in self.h5:
            return False
        node = self.h5[group]
        if isinstance(node, h5py.Dataset):
            return True
        if name is None:
            return False
        return name in node

    def get_names(self, groups: str, h5_group: h5py.Group | None = None) -> list[str]:
        """List dataset names below the ``/`` separated path ``groups``.

        A ``*`` component matches every sub-group; an empty path lists the
        datasets of the current group.
        """
        if h5_group is None:
            h5_group = self.h5
        head, _, remainder = groups.partition("/")

        if head == "":
            return [node.name.split("/")[-1] for node in h5_group.values() if isinstance(node, h5py.Dataset)]

        names: list[str] = []
        if head == "*":
            for key in h5_group.keys():
                if isinstance(h5_group[key], h5py.Group):
                    names.extend(self.get_names(remainder, h5_group[key]))
        elif head in h5_group:
            names.extend(self.get_names(remainder, h5_group[head]))
        return names

    def get_group(self) -> list[str]:
        """List the top-level keys of the container (empty when it is not open)."""
        return list(self.h5.keys()) if self.h5 is not None else []

    def _get_dataset(self, groups: str, name: str, h5_group: h5py.Group | None = None) -> h5py.Dataset:
        """Resolve ``groups``/``name`` to a dataset, or ``None`` when absent.

        A ``*`` component explores every sub-group; the last match wins.
        """
        if h5_group is None:
            h5_group = self.h5
        head, _, remainder = groups.partition("/")

        if head == "":
            return h5_group[name] if name in h5_group else None

        found = None
        if head == "*":
            for key in h5_group.keys():
                if isinstance(h5_group[key], h5py.Group):
                    candidate = self._get_dataset(remainder, name, h5_group[key])
                    if candidate is not None:
                        found = candidate
        elif head in h5_group:
            found = self._get_dataset(remainder, name, h5_group[head])
        return found

    def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
        """Return the shape and attributes of a dataset without reading its values.

        Raises:
            NameError: if the dataset does not exist.
        """
        dataset = self._require_dataset(groups, name)
        return (
            dataset.shape,
            Attribute({k: str(v) for k, v in dataset.attrs.items()}),
        )
class SitkFile(AbstractFile):
    """Dataset back-end stored as individual files inside a directory.

    ``filename`` is used as a path prefix to which entry names and extensions
    are appended. Images use ``file_format``; transforms, landmarks, XML,
    VTK poly-data and raw arrays use ``.itk.txt``, ``.fcsv``, ``.xml``,
    ``.vtk`` and ``.npy`` respectively.
    """

    def __init__(self, filename: str, read: bool, file_format: str) -> None:
        """Store the path prefix, the access mode and the image file format."""
        self.filename = filename
        self.read = read
        self.file_format = file_format

    @staticmethod
    def _normalize_slices(slices: tuple[slice, ...], shape: list[int]) -> tuple[slice, ...]:
        """Resolve every slice against ``shape`` into explicit start/stop/step values.

        Raises:
            ValueError: if the number of slices differs from the number of axes.
        """
        if len(slices) != len(shape):
            raise ValueError(f"Expected {len(shape)} slices, got {len(slices)}.")
        return tuple(slice(*item.indices(size)) for item, size in zip(slices, shape))

    @staticmethod
    def _supports_direct_slice(slices: tuple[slice, ...]) -> bool:
        """Tell whether all slices are contiguous (step ``None`` or ``1``)."""
        return all(item.step in (None, 1) for item in slices)

    @staticmethod
    def _transform_type(transform: sitk.Transform) -> str:
        """Return the stored type tag of a supported transform.

        Raises:
            NameError: if the transform is not a B-spline, affine or Euler 3-D transform.
        """
        supported = (
            (sitk.BSplineTransform, "BSplineTransform_double_3_3"),
            (sitk.AffineTransform, "AffineTransform_double_3_3"),
            (sitk.Euler3DTransform, "Euler3DTransform_double_3_3"),
        )
        for transform_class, type_name in supported:
            if isinstance(transform, transform_class):
                return type_name
        raise NameError(f"Unsupported transform type '{type(transform).__name__}'.")

    def _resolve_data_path(self, name: str) -> str | None:
        """Return the file holding ``name``, or ``None`` when there is none.

        The image in ``file_format`` is preferred, then the auxiliary
        suffixes, then any file sharing the base name.
        """
        base = f"{self.filename}{name}"
        direct = f"{base}.{self.file_format}"
        if os.path.exists(direct):
            return direct
        for suffix in (".itk.txt", ".fcsv", ".xml", ".vtk", ".npy"):
            candidate = f"{base}{suffix}"
            if os.path.exists(candidate):
                return candidate
        # Escape the base so that names containing ``[``, ``*`` or ``?`` are
        # matched literally.
        matches = glob.glob(f"{glob.escape(base)}.*")
        return matches[0] if matches else None

    def _file_to_image_slice(self, name: str, path: str, slices: tuple[slice, ...]) -> tuple[np.ndarray, Attribute]:
        """Read a sub-region of an image file.

        Contiguous slices are read through the extraction interface of the
        image reader and ``Origin`` is recomputed for the extracted region;
        strided slices fall back to reading the full entry.
        """
        reader = sitk.ImageFileReader()
        reader.SetFileName(path)
        reader.ReadImageInformation()
        spatial_shape = list(reversed(reader.GetSize()))
        data_shape = [reader.GetNumberOfComponents(), *spatial_shape]

        normalized = self._normalize_slices(slices, data_shape)
        if not self._supports_direct_slice(normalized):
            data, attributes = self.file_to_data("", name)
            return data[normalized], attributes

        # Array axes are the reverse of the (x, y, z) order used by the reader.
        extract_index_xyz = [item.start for item in reversed(normalized[1:])]
        extract_size_xyz = [item.stop - item.start for item in reversed(normalized[1:])]
        reader.SetExtractIndex(extract_index_xyz)
        reader.SetExtractSize(extract_size_xyz)
        image = reader.Execute()
        data, attributes = image_to_data(image)

        origin = np.asarray(reader.GetOrigin(), dtype=np.float64)
        spacing = np.asarray(reader.GetSpacing(), dtype=np.float64)
        direction = np.asarray(reader.GetDirection(), dtype=np.float64).reshape(len(spacing), len(spacing))
        attributes["Origin"] = origin + direction @ (np.asarray(extract_index_xyz, dtype=np.float64) * spacing)
        # Spatial axes were already cropped by the reader; only the channel
        # slice remains to be applied.
        return data[normalized[:1] + tuple(slice(None) for _ in normalized[1:])], attributes

    def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
        """Read the entry ``name`` from whichever supported file holds it.

        Candidates are probed in the order transform, landmarks, XML, VTK
        poly-data, ``.npy`` and finally any image sharing the base name.
        ``group`` is not used by this back-end.

        Raises:
            NameError: if no file matches ``name`` or a transform type is unsupported.
        """
        attributes = Attribute()
        base = f"{self.filename}{name}"

        if os.path.exists(f"{base}.itk.txt"):
            transform = sitk.ReadTransform(f"{base}.itk.txt")
            if isinstance(transform, sitk.CompositeTransform):
                transforms = [transform.GetNthTransform(i) for i in range(transform.GetNumberOfTransforms())]
            else:
                transforms = [transform]
            parameters = []
            for i, item in enumerate(transforms):
                attributes[f"{i}:Transform"] = self._transform_type(item)
                attributes[f"{i}:FixedParameters"] = item.GetFixedParameters()
                parameters.append(np.asarray(item.GetParameters()))
            # Pad with NaN so that transforms of different sizes fit one array.
            longest = max(len(row) for row in parameters)
            data = np.array([np.pad(row, (0, longest - len(row)), constant_values=np.nan) for row in parameters])
        elif os.path.exists(f"{base}.fcsv"):
            data = read_landmarks(Path(f"{base}.fcsv"))
        elif os.path.exists(f"{base}.xml"):
            with open(f"{base}.xml", "rb") as xml_file:
                # NOTE(review): this branch returns the parsed XML root rather
                # than an (array, attributes) pair; kept unchanged because
                # callers may rely on it - confirm before harmonising.
                return etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()  # nosec B320
        elif os.path.exists(f"{base}.vtk"):
            import vtk

            vtk_reader = vtk.vtkPolyDataReader()
            vtk_reader.SetFileName(f"{base}.vtk")
            vtk_reader.Update()
            points = vtk_reader.GetOutput().GetPoints()
            data = np.asarray([list(points.GetPoint(i)) for i in range(points.GetNumberOfPoints())])
        elif os.path.exists(f"{base}.npy"):
            data = np.load(f"{base}.npy")
        else:
            matches = glob.glob(f"{glob.escape(base)}.*")
            if not matches:
                # Previously fell through to an unbound local variable.
                raise NameError(f"Data '{name}' not found in dataset '{self.filename}'.")
            data, image_attributes = image_to_data(sitk.ReadImage(matches[0]))
            attributes.update(image_attributes)
        return data, attributes

    def file_to_data_slice(self, group: str, name: str, slices: tuple[slice, ...]) -> tuple[np.ndarray, Attribute]:
        """Read the part of an entry selected by ``slices``.

        ``.npy`` files are memory-mapped, auxiliary formats are read fully and
        then sliced, images go through the region reader.

        Raises:
            NameError: if no file matches ``name``.
        """
        path = self._resolve_data_path(name)
        if path is None:
            raise NameError(f"Data '{name}' not found in dataset '{self.filename}'.")
        if path.endswith(".npy"):
            return np.asarray(np.load(path, mmap_mode="r")[slices]), Attribute()
        if path.endswith((".itk.txt", ".fcsv", ".xml", ".vtk")):
            data, attributes = self.file_to_data(group, name)
            return data[slices], attributes
        return self._file_to_image_slice(name, path, slices)

    def file_to_data_statistics(
        self,
        group: str,
        name: str,
        channels: list[int] | None = None,
    ) -> dict[str, float]:
        """Compute min/max/mean/std of an entry, optionally limited to ``channels``.

        Raises:
            NameError: if no file matches ``name``.
        """
        path = self._resolve_data_path(name)
        if path is None:
            raise NameError(f"Data '{name}' not found in dataset '{self.filename}'.")

        if path.endswith(".npy"):
            data = np.load(path, mmap_mode="r")
        elif path.endswith((".itk.txt", ".fcsv", ".xml", ".vtk")):
            data, _ = self.file_to_data(group, name)
        else:
            image = sitk.ReadImage(path)
            data = sitk.GetArrayViewFromImage(image)
            if image.GetNumberOfComponentsPerPixel() == 1:
                data = np.expand_dims(data, 0)
            else:
                data = np.moveaxis(data, -1, 0)

        if channels is not None:
            data = data[channels]
        return _finalize_running_statistics(_update_running_statistics(None, data))

    def is_vtk_polydata(self, obj):
        """Tell whether ``obj`` is a ``vtkPolyData``; ``False`` when VTK is not installed."""
        try:
            import vtk
        except ImportError:
            return False
        return isinstance(obj, vtk.vtkPolyData)

    def __enter__(self):
        """Nothing to open for a directory-backed entry."""

    def __exit__(self, exc_type, value, traceback):
        """Nothing to close for a directory-backed entry."""

    def data_to_file(
        self,
        name: str,
        data: sitk.Image | sitk.Transform | np.ndarray,
        attributes: Attribute | None = None,
    ) -> None:
        """Write ``data`` under ``name`` using the format that matches its kind.

        In order of precedence: images, transforms (``.itk.txt``), VTK
        poly-data (``.vtk``), arrays with image geometry attributes (converted
        to an image), non-empty ``(n, 3)`` arrays (landmarks, rounded to four
        decimals), arrays with a ``path`` attribute (XML node) and finally raw
        arrays (``.npy``).
        """
        if attributes is None:
            attributes = Attribute()
        os.makedirs(self.filename, exist_ok=True)
        base = f"{self.filename}{name}"

        if isinstance(data, sitk.Image):
            for key, value in attributes.items():
                if value and len(value):
                    data.SetMetaData(key, value)
            sitk.WriteImage(data, f"{base}.{self.file_format}")
        elif isinstance(data, sitk.Transform):
            sitk.WriteTransform(data, f"{base}.itk.txt")
        elif self.is_vtk_polydata(data):
            import vtk

            vtk_writer = vtk.vtkPolyDataWriter()
            vtk_writer.SetFileName(f"{base}.vtk")
            vtk_writer.SetInputData(data)
            vtk_writer.Write()
        elif is_an_image(attributes):
            self.data_to_file(name, data_to_image(data, attributes), attributes)
        elif len(data.shape) == 2 and data.shape[1] == 3 and data.shape[0] > 0:
            write_landmarks(np.round(data, 4), Path(f"{base}.fcsv"))
        elif "path" in attributes:
            xml_path = f"{base}.xml"
            if os.path.exists(xml_path):
                with open(xml_path, "rb") as xml_file:
                    root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()  # nosec B320
            else:
                root = etree.Element(name)

            # Walk (and create on demand) the ``:`` separated node path.
            node = root
            for node_name in attributes["path"].split(":"):
                child = node.find(node_name)
                if child is None:
                    child = etree.SubElement(node, node_name)
                node = child

            for stored_key in attributes.keys():
                # Stored keys carry a ``_<n>`` suffix; strip it to get the name.
                attribute = "_".join(stored_key.split("_")[:-1])
                if attribute != "path":
                    node.set(attribute, attributes[attribute])
            if data.size > 0:
                node.text = ", ".join(map(str, data.flatten()))
            with open(xml_path, "wb") as f:
                f.write(etree.tostring(root, pretty_print=True, encoding="utf-8"))
        else:
            np.save(f"{base}.npy", data)

    def is_exist(self, group: str, name: str | None = None) -> bool:
        """Tell whether a file named ``group`` with a supported extension exists (``name`` is unused)."""
        base = f"{self.filename}{group}"
        return any(os.path.exists(base + "." + ext) for ext in SUPPORTED_EXTENSIONS)

    def get_names(self, group: str) -> list[str]:
        """Not available for this back-end."""
        raise NotImplementedError()

    def get_group(self):
        """Not available for this back-end."""
        raise NotImplementedError()

    def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
        """Return the shape and attributes of an entry.

        Images in ``file_format`` are inspected through their header only;
        any other entry is read in full.
        """
        prefix = group if group is not None else ""
        image_path = f"{self.filename}{prefix}{name}.{self.file_format}"
        if not os.path.exists(image_path):
            data, attributes = self.file_to_data(prefix, name)
            return data.shape, attributes

        reader = sitk.ImageFileReader()
        reader.SetFileName(image_path)
        reader.ReadImageInformation()
        attributes = Attribute()
        attributes["Origin"] = np.asarray(reader.GetOrigin())
        attributes["Spacing"] = np.asarray(reader.GetSpacing())
        attributes["Direction"] = np.asarray(reader.GetDirection())
        for key in reader.GetMetaDataKeys():
            attributes[key] = reader.GetMetaData(key)
        # Reverse for every dimensionality (previously only 3-D) so the shape
        # uses the same axis order as ``_file_to_image_slice`` and the arrays.
        size = [reader.GetNumberOfComponents(), *reversed(reader.GetSize())]
        return size, attributes
class File:
    """Context manager that picks and opens the back-end matching ``file_format``."""

    def __init__(self, filename: str, read: bool, file_format: str) -> None:
        """Record the target, the access mode and the format; nothing is opened yet."""
        self.file: "Dataset.AbstractFile" | None = None
        self.filename = filename
        self.read = read
        self.file_format = file_format

    def __enter__(self) -> "Dataset.AbstractFile":
        """Instantiate the HDF5 or directory back-end, enter it and return it."""
        if self.file_format != "h5":
            backend = Dataset.SitkFile(self.filename + "/", self.read, self.file_format)
        else:
            backend = Dataset.H5File(self.filename, self.read)
        self.file = backend
        self.file.__enter__()
        return self.file

    def __exit__(self, exc_type, value, traceback):
        """Forward the exit to the back-end when one was opened."""
        if self.file is None:
            return
        self.file.__exit__(exc_type, value, traceback)
def __init__(self, filename: str | Path, file_format: str) -> None: if file_format != "h5" and not str(filename).endswith("/"): filename = f"{filename}/" self.is_directory = str(filename).endswith("/") self.filename = str(filename) self.file_format = file_format def _exists_on_disk(self) -> bool: if os.path.exists(self.filename): return True return self.file_format == "h5" and os.path.exists(f"{self.filename}.h5")
def write(
    self,
    group: str,
    name: str,
    data: sitk.Image | sitk.Transform | np.ndarray,
    attributes: Attribute | None = None,
):
    """Store ``data`` as entry ``name`` of ``group``.

    For a directory dataset, the leading components of ``group`` become
    sub-directories placed before ``name`` and the last component names the
    stored item. For a single-file dataset, the item is stored under
    ``"<group>/<name>"``.
    """
    if attributes is None:
        attributes = Attribute()

    if not self.is_directory:
        with Dataset.File(self.filename, False, self.file_format) as file:
            file.data_to_file(f"{group}/{name}", data, attributes)
        return

    if not os.path.exists(self.filename):
        os.makedirs(self.filename)
    sub_directory, separator, leaf = group.rpartition("/")
    if separator:
        name = f"{sub_directory}/{name}"
        group = leaf
    with Dataset.File(f"{self.filename}{name}", False, self.file_format) as file:
        file.data_to_file(group, data, attributes)
def read_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
    """Read the whole entry ``name`` of ``groups``.

    For directory datasets the sub-directories matching ``groups`` are
    searched; when several contain ``name`` the last one is returned.

    Raises:
        NameError: if the dataset or the requested entry does not exist.
    """
    if not self._exists_on_disk():
        raise NameError(f"Dataset {self.filename} not found")

    # Files are opened read-only, like in the other readers: opening in write
    # mode stamped a new ``Date`` on HDF5 files and failed on read-only media.
    if not self.is_directory:
        with Dataset.File(self.filename, True, self.file_format) as file:
            return file.file_to_data(groups, name)

    group = groups.split("/")[-1]
    extension = ".h5" if self.file_format == "h5" else ""
    matches = [
        sub_directory
        for sub_directory in self._get_sub_directories(groups)
        if os.path.exists(f"{self.filename}{sub_directory}{name}{extension}")
    ]
    if not matches:
        # Previously surfaced as an unbound local variable.
        raise NameError(f"Dataset entry '{groups}/{name}' not found in {self.filename}.")
    # Only the match that is actually returned is read from disk.
    with Dataset.File(f"{self.filename}{matches[-1]}{name}", True, self.file_format) as file:
        return file.file_to_data("", group)
def read_data_slice(self, groups: str, name: str, slices: tuple[slice, ...]) -> tuple[np.ndarray, Attribute]:
    """Read the part of entry ``name`` of ``groups`` selected by ``slices``.

    For directory datasets the first matching sub-directory is used.

    Raises:
        NameError: if the dataset or the requested entry does not exist.
    """
    if not self._exists_on_disk():
        raise NameError(f"Dataset {self.filename} not found")

    if not self.is_directory:
        with Dataset.File(self.filename, True, self.file_format) as file:
            return file.file_to_data_slice(groups, name, slices)

    for sub_directory in self._get_sub_directories(groups):
        leaf_group = groups.split("/")[-1]
        entry = f"{self.filename}{sub_directory}{name}"
        if not os.path.exists(f"{entry}{'.h5' if self.file_format == 'h5' else ''}"):
            continue
        with Dataset.File(entry, True, self.file_format) as file:
            return file.file_to_data_slice("", leaf_group, slices)
    raise NameError(f"Dataset entry '{groups}/{name}' not found in {self.filename}.")
def read_data_statistics(
    self,
    groups: str,
    name: str,
    channels: list[int] | None = None,
) -> dict[str, float]:
    """Return the statistics of entry ``name`` of ``groups`` computed by the back-end.

    For directory datasets the first matching sub-directory is used.

    Raises:
        NameError: if the dataset or the requested entry does not exist.
    """
    if not self._exists_on_disk():
        raise NameError(f"Dataset {self.filename} not found")

    if not self.is_directory:
        with Dataset.File(self.filename, True, self.file_format) as file:
            return file.file_to_data_statistics(groups, name, channels)

    for sub_directory in self._get_sub_directories(groups):
        leaf_group = groups.split("/")[-1]
        entry = f"{self.filename}{sub_directory}{name}"
        if not os.path.exists(f"{entry}{'.h5' if self.file_format == 'h5' else ''}"):
            continue
        with Dataset.File(entry, True, self.file_format) as file:
            return file.file_to_data_statistics("", leaf_group, channels)
    raise NameError(f"Dataset entry '{groups}/{name}' not found in {self.filename}.")
def read_transform(self, group: str, name: str) -> sitk.Transform:
    """Rebuild the SimpleITK transform stored as entry ``name`` of ``group``.

    One transform is created per ``<i>:Transform`` attribute, configured with
    the matching ``<i>:FixedParameters`` attribute and row ``i`` of the stored
    parameters. Several transforms are wrapped in a composite transform.

    Raises:
        NameError: if the dataset is missing, no transform is stored, or a
            stored transform type is not supported.
    """
    if not self._exists_on_disk():
        raise NameError(f"Dataset {self.filename} not found")
    transform_parameters, attribute = self.read_data(group, name)

    factories = {
        "Euler3DTransform_double_3_3": lambda: sitk.Euler3DTransform(),
        "AffineTransform_double_3_3": lambda: sitk.AffineTransform(3),
        "BSplineTransform_double_3_3": lambda: sitk.BSplineTransform(3),
    }
    transform_types = [v for k, v in attribute.items() if k.endswith(":Transform_0")]
    if not transform_types:
        raise NameError(f"No transform found in entry '{group}/{name}' of {self.filename}.")

    transforms = []
    for i, transform_type in enumerate(transform_types):
        if transform_type not in factories:
            # An unknown type used to silently reuse the previous transform.
            raise NameError(f"Unsupported transform type '{transform_type}' in entry '{group}/{name}'.")
        transform = factories[transform_type]()
        transform.SetFixedParameters(ast.literal_eval(attribute[f"{i}:FixedParameters"]))

        # Rows may be right-padded with NaN when transforms of different sizes
        # share one array; drop that padding before applying the parameters.
        parameters = np.asarray(transform_parameters[i], dtype=np.float64)
        valid = np.flatnonzero(~np.isnan(parameters))
        parameters = parameters[: valid[-1] + 1] if valid.size else parameters[:0]
        transform.SetParameters(tuple(parameters.tolist()))
        transforms.append(transform)
    return sitk.CompositeTransform(transforms) if len(transforms) > 1 else transforms[0]
def read_image(self, group: str, name: str):
    """Read entry ``name`` of ``group`` and convert it to a SimpleITK image."""
    array, metadata = self.read_data(group, name)
    image = data_to_image(array, metadata)
    return image
def get_size(self, group: str) -> int:
    """Return how many entry names ``get_names`` reports for ``group``."""
    names = self.get_names(group)
    return len(names)
def is_group_exist(self, group: str) -> bool:
    """Tell whether ``group`` holds at least one entry."""
    entry_count = self.get_size(group)
    return entry_count >= 1
def is_dataset_exist(self, group: str, name: str) -> bool:
    """Tell whether ``name`` is one of the entries reported for ``group``."""
    available = self.get_names(group)
    return name in available
def _get_sub_directories(self, groups: str, sub_directory: str = ""): group = groups.split("/")[0] sub_directories = [] if len(groups.split("/")) == 1: sub_directories.append(sub_directory) elif group == "*": for k in os.listdir(f"{self.filename}{sub_directory}"): if not os.path.isfile(f"{self.filename}{sub_directory}{k}"): sub_directories.extend( self._get_sub_directories( "/".join(groups.split("/")[1:]), f"{sub_directory}{k}/", ) ) else: sub_directory = f"{sub_directory}{group}/" if os.path.exists(f"{self.filename}{sub_directory}"): sub_directories.extend(self._get_sub_directories("/".join(groups.split("/")[1:]), sub_directory)) return sub_directories
def get_names(self, groups: str, index: list[int] | None = None) -> list[str]:
    """List, in sorted order, the entries that provide ``groups``.

    For directory datasets every entry of the matching sub-directories is
    probed for the last component of ``groups``. For single-file datasets the
    back-end is asked directly.

    Args:
        groups: ``/`` separated group path.
        index: optional positions (in the sorted list) to keep.
    """
    names: list[str] = []
    if self.is_directory:
        is_h5 = self.file_format == "h5"
        group = groups.split("/")[-1]
        for sub_directory in self._get_sub_directories(groups):
            directory = f"{self.filename}{sub_directory}"
            if not os.path.exists(directory):
                continue
            for entry in sorted(os.listdir(directory)):
                path = f"{directory}{entry}"
                # In HDF5 mode only ``.h5`` files are candidates; any other
                # file would be opened as ``<file>.h5`` and fail.
                if is_h5 and not (entry.endswith(".h5") and os.path.isfile(path)):
                    continue
                with Dataset.File(path, True, self.file_format) as file:
                    if file.is_exist(group):
                        # Strip the extension only, not inner ".h5" occurrences.
                        names.append(entry.removesuffix(".h5") if is_h5 else entry)
    else:
        with Dataset.File(self.filename, True, self.file_format) as file:
            names = file.get_names(groups)

    selected = None if index is None else set(index)
    return [name for i, name in enumerate(sorted(names)) if selected is None or i in selected]
def get_group(self):
    """List the groups available in the dataset.

    For directory datasets every file is visited: its extension is dropped
    (text after the first ``.``) and the directory that directly contains it
    is removed from its relative path; the distinct results are returned.
    For single-file datasets the back-end is asked directly.
    """
    if not self.is_directory:
        with Dataset.File(self.filename, True, self.file_format) as dataset_file:
            return list(dataset_file.get_group())

    discovered = set()
    for root, _, files in os.walk(self.filename):
        for file in files:
            stem = file.split(".")[0]
            parts = Path(root, stem).relative_to(self.filename).as_posix().split("/")
            if len(parts) >= 2:
                # Drop the directory holding the file.
                parts.pop(-2)
            discovered.add("/".join(parts))
    return list(discovered)
def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
    """Return the shape and attributes of entry ``name`` of ``groups``.

    For directory datasets the sub-directories matching ``groups`` are
    searched; when several contain ``name`` the last one is used.

    Raises:
        NameError: if a directory dataset holds no matching entry.
    """
    if not self.is_directory:
        with Dataset.File(self.filename, True, self.file_format) as file:
            return file.get_infos(groups, name)

    group = groups.split("/")[-1]
    extension = ".h5" if self.file_format == "h5" else ""
    matches = [
        sub_directory
        for sub_directory in self._get_sub_directories(groups)
        if os.path.exists(f"{self.filename}{sub_directory}{name}{extension}")
    ]
    if not matches:
        # Previously surfaced as an unbound local variable.
        raise NameError(f"Dataset entry '{groups}/{name}' not found in {self.filename}.")
    # Only the match that is actually returned is inspected.
    with Dataset.File(f"{self.filename}{matches[-1]}{name}", True, self.file_format) as file:
        return file.get_infos("", group)
[docs] def get_statistics(self, groups: str) -> dict[str, dict[str, dict[str, float | list[float]]]]: names = self.get_names(groups) stats = {} for name in names: data, attr = self.read_data(groups, name) min_, max_ = data.min(), data.max() mean_ = data.mean() std_ = data.std() # Percentiles in ONE call p25, p50, p75 = np.percentile(data, (25, 50, 75)) stats[name] = { "min": float(min_), "max": float(max_), "mean": float(mean_), "std": float(std_), "25pc": float(p25), "50pc": float(p50), "75pc": float(p75), "shape": list(data.shape), "spacing": attr.get_np_array("Spacing").tolist(), } result: dict[str, dict[str, dict[str, Any]]] = {} result["case"] = {} for name, v in stats.items(): for metric_name, value in v.items(): if metric_name not in result["case"]: result["case"][metric_name] = {} result["case"][metric_name][name] = value result["aggregates"] = {} tmp: dict[str, list[float]] = {} for _, v in stats.items(): for metric_name, _ in v.items(): if metric_name not in tmp: tmp[metric_name] = [] tmp[metric_name].append(v[metric_name]) for metric_name, values in tmp.items(): if isinstance(values[0], float): result["aggregates"][metric_name] = { "max": float(np.nanmax(values)) if np.any(~np.isnan(values)) else np.nan, "min": float(np.nanmin(values)) if np.any(~np.isnan(values)) else np.nan, "std": float(np.nanstd(values)) if np.any(~np.isnan(values)) else np.nan, "25pc": float(np.nanpercentile(values, 25)) if np.any(~np.isnan(values)) else np.nan, "50pc": float(np.nanpercentile(values, 50)) if np.any(~np.isnan(values)) else np.nan, "75pc": float(np.nanpercentile(values, 75)) if np.any(~np.isnan(values)) else np.nan, "mean": float(np.nanmean(values)) if np.any(~np.isnan(values)) else np.nan, "count": float(np.count_nonzero(~np.isnan(values))) if np.any(~np.isnan(values)) else np.nan, } else: p25, p50, p75 = np.nanpercentile(values, (25, 50, 75)) result["aggregates"][metric_name] = { "max": np.nanmax(values, axis=0).tolist(), "min": np.nanmin(values, axis=0).tolist(), 
"std": np.nanstd(values, axis=0).tolist(), "mean": np.nanmean(values, axis=0).tolist(), } return result