# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Reusable network blocks and tensor graph helpers for KonfAI models."""
import ast
import importlib
from collections.abc import Callable
from enum import Enum
import SimpleITK as sitk # noqa: N813
import torch
from konfai.network import network
from konfai.utils.config import config
[docs]
class NormMode(Enum):
    """Normalization strategies selectable by name in KonfAI block configurations."""

    # NOTE(review): the value of NONE is the 1-tuple ``(0,)`` because of the
    # trailing comma; it is kept as-is so existing value lookups keep working.
    NONE = (0,)
    BATCH = 1
    INSTANCE = 2
    GROUP = 3
    LAYER = 4
    SYNCBATCH = 5
    INSTANCE_AFFINE = 6
[docs]
def get_norm(norm_mode: Enum, channels: int, dim: int) -> torch.nn.Module | None:
"""Instantiate the normalization layer matching the requested mode."""
if norm_mode == NormMode.BATCH:
return get_torch_module("BatchNorm", dim=dim)(channels, affine=True, track_running_stats=True)
if norm_mode == NormMode.INSTANCE:
return get_torch_module("InstanceNorm", dim=dim)(channels, affine=False, track_running_stats=False)
if norm_mode == NormMode.INSTANCE_AFFINE:
return get_torch_module("InstanceNorm", dim=dim)(channels, affine=True, track_running_stats=False)
if norm_mode == NormMode.SYNCBATCH:
return torch.nn.SyncBatchNorm(channels, affine=True, track_running_stats=True)
if norm_mode == NormMode.GROUP:
return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
if norm_mode == NormMode.LAYER:
return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
return None
[docs]
class UpsampleMode(Enum):
    """Upsampling strategies: learned transposed convolution or interpolation."""

    # Both values are 1-tuples (trailing commas); kept unchanged.
    CONV_TRANSPOSE = (0,)
    UPSAMPLE = (1,)
[docs]
class DownsampleMode(Enum):
    """Downsampling strategies: max pooling, average pooling, or strided convolution."""

    # The two pooling values are 1-tuples (trailing commas); kept unchanged.
    MAXPOOL = (0,)
    AVGPOOL = (1,)
    CONV_STRIDE = 2
[docs]
def get_torch_module(name_fonction: str, dim: int | None = None) -> torch.nn.Module:
"""Return a dimensional PyTorch module class such as ``Conv2d`` or ``Conv3d``."""
return getattr(
importlib.import_module("torch.nn"),
f"{name_fonction}" + (f"{dim}d" if dim is not None else ""),
)
[docs]
@config()
class BlockConfig:
    """Settings for one convolution stage: convolution, normalization, activation."""

    def __init__(
        self,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        bias=True,
        activation: str | Callable[[], torch.nn.Module] | None = "ReLU",
        norm_mode: str | NormMode | Callable[[int], torch.nn.Module] = "NONE",
    ) -> None:
        """Store the stage settings and resolve ``norm_mode`` given by name.

        Args:
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            bias: Whether the convolution has a bias term.
            activation: ``None``, a factory callable, or a string of the form
                ``"Name;arg1;arg2"`` naming a ``torch.nn`` class and its
                positional arguments as Python literals.
            norm_mode: A ``NormMode`` member, its name as a string, or a
                callable building a layer from a channel count.
        """
        self.kernel_size = kernel_size
        self.bias = bias
        self.stride = stride
        self.padding = padding
        self.activation = activation
        self.norm_mode = norm_mode
        # Strings are treated as NormMode member names (KeyError if unknown).
        self.norm: NormMode | Callable[[int], torch.nn.Module] | None = (
            NormMode[norm_mode] if isinstance(norm_mode, str) else norm_mode
        )

    def get_conv(self, in_channels: int, out_channels: int, dim: int) -> torch.nn.Conv3d:
        """Create the ``Conv{dim}d`` layer described by this configuration."""
        conv_cls = get_torch_module("Conv", dim=dim)
        return conv_cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=self.bias,
        )

    def get_norm(self, channels: int, dim: int) -> torch.nn.Module:
        """Create the normalization layer, or return ``None`` when none is configured."""
        norm_spec = self.norm
        if norm_spec is None:
            return None
        if isinstance(norm_spec, NormMode):
            # Enum members go through the module-level factory.
            return get_norm(norm_spec, channels, dim)
        # Anything else is treated as a user-supplied factory.
        return norm_spec(channels)

    def get_activation(self) -> torch.nn.Module:
        """Create the activation module, or return ``None`` when ``activation`` is ``None``.

        The literal string ``"None"`` yields ``torch.nn.Identity()``.
        """
        spec = self.activation
        if spec is None:
            return None
        if not isinstance(spec, str):
            return spec()
        if spec == "None":
            return torch.nn.Identity()
        # "Name;arg1;arg2" -> torch.nn.Name(arg1, arg2), arguments parsed as literals.
        name, *raw_args = spec.split(";")
        return get_torch_module(name)(*[ast.literal_eval(raw) for raw in raw_args])
[docs]
class ConvBlock(network.ModuleArgsDict):
    """Chain of convolution stages, each optionally followed by norm and activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_configs: list[BlockConfig],
        dim: int,
        alias: list[list[str]] = [[], [], []],
    ) -> None:
        """Register ``Conv_i`` / ``Norm_i`` / ``Activation_i`` for every configuration.

        Args:
            in_channels: Input channels of the first convolution.
            out_channels: Output channels of every convolution; also the input
                channels of every convolution after the first.
            block_configs: One ``BlockConfig`` per stage.
            dim: Spatial dimensionality forwarded to the layer factories.
            alias: Alias lists forwarded to ``add_module`` for the convolution
                (index 0), normalization (index 1) and activation (index 2).
        """
        super().__init__()
        stage_in = in_channels
        for index, stage in enumerate(block_configs):
            conv_layer = stage.get_conv(stage_in, out_channels, dim)
            self.add_module(f"Conv_{index}", conv_layer, alias=alias[0])

            # Normalization and activation are only registered when configured.
            if (norm_layer := stage.get_norm(out_channels, dim)) is not None:
                self.add_module(f"Norm_{index}", norm_layer, alias=alias[1])
            if (activation_layer := stage.get_activation()) is not None:
                self.add_module(f"Activation_{index}", activation_layer, alias=alias[2])

            stage_in = out_channels
[docs]
class ResBlock(network.ModuleArgsDict):
    """Residual block with optional projection on the skip path."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_configs: list[BlockConfig],
        dim: int,
        alias: list[list[str]] = [[], [], [], [], []],
    ) -> None:
        """Register the main path, the optional skip projection, the sum and a final ReLU.

        Args:
            in_channels: Input channels of the first convolution.
            out_channels: Output channels of every convolution.
            block_configs: One ``BlockConfig`` per main-path stage.
            dim: Spatial dimensionality forwarded to the layer factories.
            alias: Alias lists forwarded to ``add_module`` for, in order, the
                convolutions, normalizations, activations, skip convolution and
                skip normalization. The default list is only read here.
        """
        super().__init__()
        # Tracked explicitly so the trailing layer name is defined even when
        # ``block_configs`` is empty (the loop variable would be unbound).
        final_index = 0
        for i, block_config in enumerate(block_configs):
            self.add_module(
                f"Conv_{i}",
                block_config.get_conv(in_channels, out_channels, dim),
                alias=alias[0],
            )
            norm = block_config.get_norm(out_channels, dim)
            if norm is not None:
                self.add_module(f"Norm_{i}", norm, alias=alias[1])
            activation = block_config.get_activation()
            if activation is not None:
                self.add_module(f"Activation_{i}", activation, alias=alias[2])

            if in_channels != out_channels:
                # 1x1 projection on branch 1 so the skip path matches the
                # channel count (and stride) of this stage.
                self.add_module(
                    "Conv_skip",
                    get_torch_module("Conv", dim)(
                        in_channels,
                        out_channels,
                        1,
                        block_config.stride,
                        bias=block_config.bias,
                    ),
                    alias=alias[3],
                    in_branch=[1],
                    out_branch=[1],
                )
                # Same guard as on the main path: do not register a ``None``
                # module when no normalization is configured.
                skip_norm = block_config.get_norm(out_channels, dim)
                if skip_norm is not None:
                    self.add_module(
                        "Norm_skip",
                        skip_norm,
                        alias=alias[4],
                        in_branch=[1],
                        out_branch=[1],
                    )
            in_channels = out_channels
            final_index = i + 1

        self.add_module("Add", Add(), in_branch=[0, 1])
        # The historical name "Norm_<n>" is kept for the closing ReLU so that
        # existing references to the module name remain valid.
        self.add_module(f"Norm_{final_index}", torch.nn.ReLU(inplace=True))
[docs]
def downsample(in_channels: int, out_channels: int, downsample_mode: DownsampleMode, dim: int) -> torch.nn.Module:
"""Return the downsampling module matching the requested strategy."""
if downsample_mode == DownsampleMode.MAXPOOL:
return get_torch_module("MaxPool", dim=dim)(2)
if downsample_mode == DownsampleMode.AVGPOOL:
return get_torch_module("AvgPool", dim=dim)(2)
if downsample_mode == DownsampleMode.CONV_STRIDE:
return get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
[docs]
def upsample(
    in_channels: int,
    out_channels: int,
    upsample_mode: UpsampleMode,
    dim: int,
    kernel_size: int | list[int] = 2,
    stride: int | list[int] = 2,
):
    """Return the upsampling module matching the requested strategy.

    Args:
        in_channels: Input channels (transposed convolution only).
        out_channels: Output channels (transposed convolution only).
        upsample_mode: ``UpsampleMode.CONV_TRANSPOSE`` builds a
            ``ConvTranspose{dim}d``; any other value builds an interpolating
            ``torch.nn.Upsample`` with a scale factor of 2.
        dim: Spatial dimensionality (1, 2 or 3 for interpolation).
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.

    Returns:
        The constructed upsampling layer.

    Raises:
        ValueError: If interpolation is requested for a ``dim`` other than
            1, 2 or 3.
    """
    if upsample_mode == UpsampleMode.CONV_TRANSPOSE:
        return get_torch_module("ConvTranspose", dim=dim)(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )

    # Interpolation path: pick the linear interpolation variant for ``dim``.
    interpolation_by_dim = {1: "linear", 2: "bilinear", 3: "trilinear"}
    if dim not in interpolation_by_dim:
        # Previously this surfaced as an UnboundLocalError on the mode variable.
        raise ValueError(f"Unsupported dimension for interpolation upsampling: {dim!r} (expected 1, 2 or 3).")
    return torch.nn.Upsample(scale_factor=2, mode=interpolation_by_dim[dim], align_corners=False)
[docs]
class Unsqueeze(torch.nn.Module):
    """Insert a size-1 axis at a fixed position."""

    def __init__(self, dim: int = 0):
        """Remember the position ``dim`` at which the new axis is inserted."""
        super().__init__()
        self.dim = dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` with a singleton axis inserted at ``self.dim``.

        The previous variadic signature handed the argument tuple itself to
        ``torch.unsqueeze``, which raised for every call; a single tensor is
        now accepted, like the other single-input blocks in this module.
        """
        return torch.unsqueeze(tensor, self.dim)
[docs]
class Permute(torch.nn.Module):
    """Reorder the axes of a tensor according to a fixed permutation."""

    def __init__(self, dims: list[int]):
        """Store ``dims``, the target order of the input axes."""
        super().__init__()
        self.dims = dims

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` with its axes rearranged to ``self.dims``."""
        axis_order = self.dims
        return torch.permute(tensor, axis_order)
[docs]
class ToChannels(Permute):
    """Permutation that moves axis ``dim + 1`` to position 1."""

    def __init__(self, dim: int):
        """Build the order ``[0, dim + 1, 1, ..., dim]``."""
        super().__init__([0, dim + 1, *range(1, dim + 1)])
[docs]
class ToFeatures(Permute):
    """Permutation that moves axis 1 to the last position."""

    def __init__(self, dim: int):
        """Build the order ``[0, 2, ..., dim + 1, 1]``."""
        super().__init__([0, *range(2, dim + 2), 1])
[docs]
class Add(torch.nn.Module):
    """Element-wise sum of any number of same-shaped tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        """Stack the inputs on a new leading axis and sum over it."""
        stacked = torch.stack(tensor)
        return torch.sum(stacked, dim=0)
[docs]
class Multiply(torch.nn.Module):
    """Element-wise (broadcasting) product of one or more tensors."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        """Multiply all inputs together, left to right.

        Two inputs give exactly ``torch.mul(a, b)`` as before. Any other
        number of inputs used to fail because ``torch.mul`` takes exactly two
        operands; the product is now folded over all inputs, and a single
        input is returned unchanged.

        Raises:
            TypeError: If called without any tensor.
        """
        if not tensor:
            raise TypeError("Multiply expects at least one tensor.")
        product = tensor[0]
        for factor in tensor[1:]:
            product = torch.mul(product, factor)
        return product
[docs]
class Concat(torch.nn.Module):
    """Concatenate any number of tensors along axis 1."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *tensor: torch.Tensor) -> torch.Tensor:
        """Join the inputs along axis 1, in the order they were given."""
        parts = tensor
        return torch.cat(parts, dim=1)
[docs]
class Print(torch.nn.Module):
    """Debug pass-through that prints the shape of whatever flows through it."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Print ``tensor.shape`` to stdout and return ``tensor`` untouched."""
        current_shape = tensor.shape
        print(current_shape)
        return tensor
[docs]
class Write(torch.nn.Module):
    """Debug pass-through that dumps ``tensor[0][0]`` to ``./Data.mha``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Write ``tensor[0][0]`` as an image file and return ``tensor`` unchanged.

        ``tensor[0][0]`` is presumably the first channel of the first sample;
        confirm against the callers' tensor layout.
        """
        # detach() is required: Tensor.numpy() refuses tensors that require
        # grad, which is the usual case inside a network being trained.
        array = tensor.detach()[0][0].cpu().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(array), "./Data.mha")
        return tensor
[docs]
class Exit(torch.nn.Module):
    """Debug block that aborts execution as soon as it is reached."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Always raise ``RuntimeError``; the input is ignored."""
        message = "The debug Exit block was executed."
        raise RuntimeError(message)
[docs]
class Detach(torch.nn.Module):
    """Cut the autograd graph at this point."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor.detach()``, i.e. the same data without gradient history."""
        detached = tensor.detach()
        return detached
[docs]
class Negative(torch.nn.Module):
    """Flip the sign of every element."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the element-wise negation of ``tensor``."""
        negated = -tensor
        return negated
[docs]
class GetShape(torch.nn.Module):
    """Expose the shape of the input as a tensor."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return a 1-D tensor holding the size of each axis of ``tensor``."""
        sizes = tensor.shape
        return torch.tensor(sizes)
[docs]
class ArgMax(torch.nn.Module):
    """Index of the maximum along one axis, with that axis kept as size 1."""

    def __init__(self, dim: int) -> None:
        """Remember ``dim``, the axis that is reduced."""
        super().__init__()
        self.dim = dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the arg-max indices along ``self.dim``; the reduced axis is retained."""
        axis = self.dim
        return torch.argmax(tensor, dim=axis, keepdim=True)
[docs]
class Select(torch.nn.Module):
    """Index a tensor with a fixed tuple of slices."""

    def __init__(self, slices: list[slice]) -> None:
        """Store ``slices`` as a tuple, one entry per leading axis."""
        super().__init__()
        self.slices = tuple(slices)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the stored slices, then squeeze axis 1 of the result if it has size 1.

        NOTE(review): only axis 1 is ever squeezed (and only when the result
        has at least two axes). The historical loop compared the axis *index*
        with 1 rather than the axis *size*; it may have been meant to drop
        every singleton axis. Behavior is kept as-is; confirm the intent
        before changing it.
        """
        selected = tensor[self.slices]
        if selected.dim() > 1:
            # squeeze(dim=1) is a no-op unless that axis has size 1.
            selected = selected.squeeze(dim=1)
        return selected
[docs]
class NormalNoise(torch.nn.Module):
def __init__(self, dim: int | None = None) -> None:
super().__init__()
self.dim = dim
[docs]
def forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self.dim is not None:
return torch.randn(self.dim).to(tensor.device)
else:
return torch.randn_like(tensor).to(tensor.device)
[docs]
class Const(torch.nn.Module):
    """Learnable tensor returned regardless of the input values."""

    def __init__(self, shape: list[int], std: float) -> None:
        """Create the parameter ``noise`` as standard-normal samples of ``shape`` scaled by ``std``."""
        super().__init__()
        initial_value = torch.randn(shape) * std
        self.noise = torch.nn.Parameter(initial_value)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the stored parameter on the device of ``tensor``; the input data is unused."""
        target_device = tensor.device
        return self.noise.to(target_device)
[docs]
class Subset(torch.nn.Module):
    """Crop a tensor along its trailing axes while keeping axes 0 and 1 whole."""

    def __init__(self, slices: list[slice]):
        """Prefix ``slices`` with two full slices covering axes 0 and 1.

        Args:
            slices: One slice per axis after the first two; any sequence of
                slices is accepted.
        """
        super().__init__()
        self.slices = [slice(None, None), slice(None, None), *slices]

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the cropped view of ``tensor``."""
        # Index with a tuple: using a list of slices as a multidimensional
        # index is deprecated non-tuple sequence indexing.
        return tensor[tuple(self.slices)]
[docs]
class View(torch.nn.Module):
    """Reinterpret a tensor with a fixed target size."""

    def __init__(self, size: list[int]):
        """Store ``size``, the shape passed to ``Tensor.view``."""
        super().__init__()
        self.size = size

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` viewed with shape ``self.size``."""
        target_size = self.size
        return tensor.view(target_size)
[docs]
class LatentDistribution(network.ModuleArgsDict):
    """Variational bottleneck: project to ``mu`` / ``log_std``, sample ``z``, decode back.

    Registered modules, in order: ``Flatten``; ``mu`` (output on branch 1);
    ``log_std`` (output on branch 2); ``z`` (inputs from branches 1 and 2,
    output on branch 3); ``Concat`` (inputs from branches 1, 2 and 3);
    ``DecoderInput`` (input from branch 3).
    """

    class LatentDistributionLinear(torch.nn.Module):
        """Linear projection of a flattened feature map to the latent size."""

        def __init__(self, shape: list[int], latent_dim: int) -> None:
            super().__init__()
            # Pass a plain int as the feature count rather than a 0-dim tensor.
            flat_features = int(torch.prod(torch.tensor(shape)))
            self.linear = torch.nn.Linear(flat_features, latent_dim)

        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
            """Project ``tensor`` and insert a singleton axis at position 1."""
            return torch.unsqueeze(self.linear(tensor), 1)

    class LatentDistributionDecoder(torch.nn.Module):
        """Linear projection from the latent size back to a tensor of ``shape``."""

        def __init__(self, shape: list[int], latent_dim: int) -> None:
            super().__init__()
            flat_features = int(torch.prod(torch.tensor(shape)))
            self.linear = torch.nn.Linear(latent_dim, flat_features)
            self.shape = shape

        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
            """Project ``tensor`` and reshape it to ``(-1, *shape)``."""
            return self.linear(tensor).view(-1, *[int(i) for i in self.shape])

    class LatentDistributionZ(torch.nn.Module):
        """Reparameterized sample ``z = mu + exp(log_std / 2) * eps``."""

        def __init__(self) -> None:
            super().__init__()

        def forward(self, mu: torch.Tensor, log_std: torch.Tensor) -> torch.Tensor:
            """Draw ``z`` with ``eps`` sampled from a standard normal distribution.

            The noise must be Gaussian (``randn_like``); uniform noise
            (``rand_like``) has mean 0.5 and would bias every sample away
            from ``mu``.
            """
            # NOTE(review): the division by 2 suggests ``log_std`` actually
            # carries a log-variance; kept as-is, confirm against the loss.
            return torch.exp(log_std / 2) * torch.randn_like(mu) + mu

    def __init__(self, shape: list[int], latent_dim: int) -> None:
        """Assemble the bottleneck graph.

        Args:
            shape: Per-sample shape of the incoming feature map; its product
                is the flattened feature count.
            latent_dim: Size of the latent vector.
        """
        super().__init__()
        self.add_module("Flatten", torch.nn.Flatten(1))
        self.add_module(
            "mu",
            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
            out_branch=[1],
        )
        self.add_module(
            "log_std",
            LatentDistribution.LatentDistributionLinear(shape, latent_dim),
            out_branch=[2],
        )
        self.add_module(
            "z",
            LatentDistribution.LatentDistributionZ(),
            in_branch=[1, 2],
            out_branch=[3],
        )
        self.add_module("Concat", Concat(), in_branch=[1, 2, 3])
        self.add_module(
            "DecoderInput",
            LatentDistribution.LatentDistributionDecoder(shape, latent_dim),
            in_branch=[3],
        )
[docs]
class Attention(network.ModuleArgsDict):
    """Attention gate built from 1x1 convolutions, a sigmoid mask and a final product."""

    def __init__(self, f_g: int, f_l: int, f_int: int, dim: int):
        """Register the gate modules.

        Args:
            f_g: Input channels of the ``W_g`` convolution (branch 1).
            f_l: Input channels of the ``W_x`` convolution (branch 0).
            f_int: Intermediate channel count shared by both projections.
            dim: Spatial dimensionality selecting the ``Conv{dim}d`` class.
        """
        super().__init__()
        conv_cls = get_torch_module("Conv", dim=dim)

        # Branch 0 is projected with stride 2, branch 1 with stride 1.
        projection_x = conv_cls(in_channels=f_l, out_channels=f_int, kernel_size=1, stride=2, padding=0)
        self.add_module("W_x", projection_x, in_branch=[0], out_branch=[0])
        projection_g = conv_cls(in_channels=f_g, out_channels=f_int, kernel_size=1, stride=1, padding=0)
        self.add_module("W_g", projection_g, in_branch=[1], out_branch=[1])

        self.add_module("Add", Add(), in_branch=[0, 1])
        self.add_module("ReLU", torch.nn.ReLU(inplace=True))

        # Collapse to a single-channel map, squash with a sigmoid, upsample by 2.
        mask_conv = conv_cls(in_channels=f_int, out_channels=1, kernel_size=1, stride=1, padding=0)
        self.add_module("Conv", mask_conv)
        self.add_module("Sigmoid", torch.nn.Sigmoid())
        self.add_module("Upsample", torch.nn.Upsample(scale_factor=2))

        # The mask is multiplied with whatever is held on branch 2.
        self.add_module("Multiply", Multiply(), in_branch=[2, 0])