Source code for konfai.metric.measure

# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Criterion and metric implementations used by KonfAI workflows."""

import copy
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from typing import cast

import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from huggingface_hub import hf_hub_download
from tqdm import tqdm

from konfai.data.patching import ModelPatch
from konfai.network.blocks import LatentDistribution
from konfai.network.network import ModelLoader, Network
from konfai.utils.config import apply_config
from konfai.utils.dataset import Attribute
from konfai.utils.utils import get_module

models_register = {}


[docs] class Criterion(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__()
[docs] def get_name(self): return self.__class__.__name__
[docs] @abstractmethod def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: raise NotImplementedError()
[docs] class CriterionWithInit(Criterion): accepts_init = True def __init__(self) -> None: super().__init__()
[docs] @abstractmethod def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str: raise NotImplementedError()
[docs] class CriterionWithAttribute(Criterion): accepts_attributes = True def __init__(self) -> None: super().__init__()
[docs] @abstractmethod def forward( # type: ignore[override] self, output: torch.Tensor, *targets: torch.Tensor, attributes: list[list[Attribute]] ) -> torch.Tensor: raise NotImplementedError()
[docs] class MaskedLoss(Criterion): def __init__( self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], mode_image_masked: bool, ) -> None: super().__init__() self.loss = loss self.mode_image_masked = mode_image_masked
[docs] @staticmethod def get_mask(targets: list[torch.Tensor]) -> torch.Tensor | None: if len(targets) == 0: return None mask = targets[0] for target in targets[1:]: mask = mask * target return mask
[docs] def forward( self, output: torch.Tensor, *targets: torch.Tensor, ) -> tuple[torch.Tensor, float]: if len(targets) == 0: raise ValueError("MaskedLoss expects at least one target tensor.") target = targets[0] mask = self.get_mask(list(targets[1:])) loss = output.new_tensor(0.0) true_nb = 0 if mask is None: loss_b = self.loss( output.float(), target.to(device=output.device).float(), ) return loss_b, loss_b.detach().item() target = target.to(device=output.device) mask = mask.to(device=output.device) for batch in range(output.shape[0]): mask_b = mask[batch, ...] == 1 if not torch.any(mask_b): continue output_b = output[batch, ...].float() target_b = target[batch, ...].float() if self.mode_image_masked: mask_b = mask_b.to(dtype=output_b.dtype) loss_b = self.loss( output_b * mask_b, target_b * mask_b, ) else: loss_b = self.loss( torch.masked_select(output_b, mask_b), torch.masked_select(target_b, mask_b), ) loss = loss + loss_b true_nb += 1 if true_nb == 0: return loss, np.nan loss = loss / true_nb return loss, loss.detach().item()
[docs] class MSE(MaskedLoss): @staticmethod def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.nn.MSELoss(reduction=reduction)(x, y) def __init__(self, reduction: str = "mean") -> None: super().__init__(partial(MSE._loss, reduction), False)
[docs] class MAE(MaskedLoss): @staticmethod def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.nn.L1Loss(reduction=reduction)(x, y) def __init__(self, reduction: str = "mean") -> None: super().__init__(partial(MAE._loss, reduction), False)
[docs] class ME(MaskedLoss): @staticmethod def _loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return (x - y).mean() def __init__(self) -> None: super().__init__(ME._loss, False)
[docs] class MAESaveMap(MAE): def __init__(self, reduction: str = "mean", dataset: str | None = None, group: str | None = None) -> None: super().__init__(reduction) self.dataset = dataset self.group = group
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor): # type: ignore[override] loss, true_loss = super().forward(output, *targets) if len(targets) == 2: error_map = ( torch.nn.L1Loss(reduction="none")( output.float() * torch.where(targets[1] == 1, 1, 0), targets[0].float() * torch.where(targets[1] == 1, 1, 0), ) .to(output.dtype) .cpu() ) else: error_map = torch.nn.L1Loss(reduction="none")(output.float(), targets[0].float()).to(output.dtype).cpu() return loss, true_loss, error_map
[docs] def get_name(self) -> str: return "MAE"
[docs] class PSNR(MaskedLoss): @staticmethod def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: mse = torch.mean((x - y).pow(2)) psnr = 10 * torch.log10(dynamic_range**2 / mse) return psnr def __init__(self, dynamic_range: float | None = None) -> None: dynamic_range = dynamic_range if dynamic_range else 1024 + 3071 super().__init__(partial(PSNR._loss, dynamic_range), False)
[docs] class SSIM(MaskedLoss): @staticmethod def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: from skimage.metrics import structural_similarity return structural_similarity( x[0][0].detach().cpu().numpy(), y[0][0].cpu().numpy(), data_range=dynamic_range, gradient=False, full=False, ) def __init__(self, dynamic_range: float | None = None) -> None: dynamic_range = dynamic_range if dynamic_range else 1024 + 3000 super().__init__(partial(SSIM._loss, dynamic_range), True)
[docs] class LPIPS(MaskedLoss):
[docs] @staticmethod def normalize(tensor: torch.Tensor) -> torch.Tensor: return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1
[docs] @staticmethod def preprocessing(tensor: torch.Tensor) -> torch.Tensor: return tensor.repeat((1, 3, 1, 1)).to(0)
@staticmethod def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: dataset_patch = ModelPatch([1, 320, 320]) dataset_patch.load(x.shape[2:]) patch_iterator = dataset_patch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y)) loss = 0 with tqdm( iterable=enumerate(patch_iterator), leave=False, total=dataset_patch.get_size(0), ) as batch_iter: for _, patch_input in batch_iter: real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1]) loss += loss_fn_alex(real, fake).flatten()[0] return loss / dataset_patch.get_size(0) def __init__(self, model: str = "alex") -> None: import lpips super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model).to(0)), True)
[docs] class TRE(Criterion): def __init__(self) -> None: super().__init__()
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor): loss = torch.linalg.norm(output - targets[0], dim=2) return loss.mean(), {f"Landmarks_{i}": v.item() for i, v in enumerate(loss.mean(0))}
[docs] class Dice(Criterion):
[docs] @staticmethod def flatten(tensor: torch.Tensor) -> torch.Tensor: return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)
[docs] @staticmethod def dice_per_channel(tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor: tensor = Dice.flatten(tensor) target = Dice.flatten(target) return (2.0 * (tensor * target).sum() + 1e-6) / (tensor.sum() + target.sum() + 1e-6)
@staticmethod def _loss(labels: list[int] | None, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target = F.interpolate(targets[0], output.shape[2:], mode="nearest") result = {} loss = torch.tensor(0, dtype=torch.float32).to(output.device) labels = labels if labels is not None else torch.unique(target) for label in labels: tp = target == label if tp.any().item(): if output.shape[1] > 1: pp = output[:, label].unsqueeze(1) else: pp = output == label loss_tmp = Dice.dice_per_channel(pp.float(), tp.float()) loss += loss_tmp result[label] = loss_tmp.item() else: result[label] = np.nan return 1 - loss / len(labels), result def __init__(self, labels: list[int] | None = None) -> None: super().__init__() self.loss = partial(Dice._loss, labels)
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> tuple[torch.Tensor, float]: mask = MaskedLoss.get_mask(list(targets[1:])) if mask is not None: return self.loss( (output * torch.where(targets[1] == 1, 1, 0)).to(torch.uint8), (targets[0] * torch.where(targets[1] == 1, 1, 0)).to(torch.uint8), ) else: return self.loss(output, targets[0])
[docs] class DiceSaveMap(Dice): def __init__(self, labels: list[int] | None = None, dataset: str | None = None, group: str | None = None) -> None: super().__init__(labels) self.dataset = dataset self.group = group
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor): # type: ignore[override] loss, true_loss = super().forward(output, *targets) if len(targets) == 2: error_map = ( torch.nn.L1Loss(reduction="none")( output * torch.where(targets[1] == 1, 1, 0), targets[0] * torch.where(targets[1] == 1, 1, 0) ) .to(torch.uint8) .cpu() ) else: error_map = torch.nn.L1Loss(reduction="none")(output, targets[0]).to(torch.uint8).cpu() return loss, true_loss, error_map
[docs] def get_name(self) -> str: return "Dice"
[docs] class GradientImages(Criterion): def __init__(self): super().__init__() @staticmethod def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: dx = image[:, :, 1:, :] - image[:, :, :-1, :] dy = image[:, :, :, 1:] - image[:, :, :, :-1] return dx, dy @staticmethod def _image_gradient_3d( image: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :] dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :] dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1] return dx, dy, dz
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target_0 = targets[0] if len(output.shape) == 5: dx, dy, dz = GradientImages._image_gradient_3d(output) if target_0 is not None: dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient_3d(target_0) dx -= dx_tmp dy -= dy_tmp dz -= dz_tmp return dx.norm() + dy.norm() + dz.norm() else: dx, dy = GradientImages._image_gradient_2d(output) if target_0 is not None: dx_tmp, dy_tmp = GradientImages._image_gradient_2d(target_0) dx -= dx_tmp dy -= dy_tmp return dx.norm() + dy.norm()
[docs] class BCE(Criterion): def __init__(self, target: float = 0) -> None: super().__init__() self.loss = torch.nn.BCEWithLogitsLoss() self.register_buffer("target", torch.tensor(target).type(torch.float32))
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target = self._buffers["target"] return self.loss(output, target.to(output.device).expand_as(output))
[docs] class PatchGanLoss(Criterion): def __init__(self, target: float = 0) -> None: super().__init__() self.loss = torch.nn.MSELoss() self.register_buffer("target", torch.tensor(target).type(torch.float32))
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target = self._buffers["target"] return self.loss(output, (torch.ones_like(output) * target).to(output.device))
[docs] class WGP(Criterion): def __init__(self) -> None: super().__init__()
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: return torch.mean((output - 1) ** 2)
[docs] class Gram(Criterion):
[docs] @staticmethod def compute_gram(tensor: torch.Tensor): (b, ch, w) = tensor.size() with torch.amp.autocast("cuda", enabled=False): return tensor.bmm(tensor.transpose(1, 2)).div(ch * w)
def __init__(self) -> None: super().__init__() self.loss = torch.nn.L1Loss(reduction="sum")
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target = targets[0] if len(output.shape) > 3: output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:]))) if len(target.shape) > 3: target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:]))) return self.loss(Gram.compute_gram(output), Gram.compute_gram(target))
[docs] class PerceptualLoss(Criterion):
[docs] class Module: def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch:nn:L1Loss": 1}) -> None: self.losses = losses self.konfai_args = os.environ["KONFAI_CONFIG_PATH"] if "KONFAI_CONFIG_PATH" in os.environ else ""
[docs] def get_loss(self) -> dict[torch.nn.Module, float]: result: dict[torch.nn.Module, float] = {} for loss, loss_value in self.losses.items(): module, name = get_module(loss, "konfai.metric.measure") result[apply_config(self.konfai_args)(getattr(module, name))()] = loss_value return result
def __init__( self, model_loader: ModelLoader = ModelLoader(), path_model: str = "name", modules: dict[str, Module] = { "UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch:nn:L1Loss": 1}) }, shape: list[int] = [128, 128, 128], ) -> None: super().__init__() self.path_model = path_model if self.path_model not in models_register: self.model = model_loader.get_model( train=False, konfai_args=os.environ["KONFAI_CONFIG_PATH"].split("PerceptualLoss")[0] + "PerceptualLoss.Model", konfai_without=[ "optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputs_criterions", "drop_p", ], ) if path_model.startswith("https"): state_dict = torch.hub.load_state_dict_from_url(path_model) state_dict = {"Model": {self.model.get_name(): state_dict["model"]}} else: state_dict = torch.load(path_model, weights_only=True) self.model.load(state_dict) models_register[self.path_model] = self.model else: self.model = models_register[self.path_model] self.shape = shape self.mode = "trilinear" if len(shape) == 3 else "bilinear" self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {} for name, losses in modules.items(): self.modules_loss[name.replace(":", ".")] = losses.get_loss() self.model.eval() self.model.requires_grad_(False) self.models: dict[int, torch.nn.Module] = {}
[docs] def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor: # if not all([tensor.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]): # tensor = F.interpolate(tensor, mode=self.mode, # size=tuple(self.shape), align_corners=False).type(torch.float32) # if tensor.shape[1] != self.model.in_channels: # tensor = tensor.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))])) # tensor = (tensor - torch.min(tensor))/(torch.max(tensor)-torch.min(tensor)) # tensor = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor) return tensor
def _compute(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) output_preprocessing = self.preprocessing(output) targets_preprocessing = [self.preprocessing(target) for target in targets] for zipped_output in zip([output_preprocessing], *[[target] for target in targets_preprocessing]): output = zipped_output[0] targets = zipped_output[1:] for zipped_layers in list( zip( self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()), *[ self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy()) for target in targets ], ) ): output_layer = zipped_layers[0][1].view( zipped_layers[0][1].shape[0], zipped_layers[0][1].shape[1], int(np.prod(zipped_layers[0][1].shape[2:])), ) for (loss_function, loss_value), target_layer in zip( self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:] ): target_layer = target_layer[1].view( target_layer[1].shape[0], target_layer[1].shape[1], int(np.prod(target_layer[1].shape[2:])), ) loss = ( loss + loss_value * loss_function(output_layer.float(), target_layer.float()) / output_layer.shape[0] ) return loss
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: if output.device.index not in self.models: del os.environ["device"] self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval() loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) if len(output.shape) == 5 and len(self.shape) == 2: for i in range(output.shape[2]): loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets]) / output.shape[2] else: loss = self._compute(output, targets) return loss.to(output)
[docs] class KLDivergence(CriterionWithInit): def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None: super().__init__() self.latent_dim = dim self.mu = torch.Tensor([mu]) self.std = torch.Tensor([std]) self.modelDim = 3 self.shape = shape self.loss = torch.nn.KLDivLoss()
[docs] def init(self, model: Network, output_group: str, target_group: str) -> str: model._compute_channels_trace(model, model.in_channels, None, None) last_module = model for name in output_group.split(".")[:-1]: last_module = last_module[name] modules = last_module._modules.copy() last_module._modules.clear() for name, value in modules.items(): last_module._modules[name] = value if name == output_group.split(".")[-1]: last_module.add_module( "LatentDistribution", LatentDistribution(shape=self.shape, latent_dim=self.latent_dim), ) return ".".join(output_group.split(".")[:-1]) + ".LatentDistribution.Concat"
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: mu = output[:, 0, :] log_std = output[:, 1, :] return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)
[docs] class Accuracy(Criterion): def __init__(self) -> None: super().__init__() self.n: int = 0 self.corrects = torch.zeros(1)
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target_0 = targets[0] self.n += output.shape[0] self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu() return self.corrects / self.n
[docs] class TripletLoss(Criterion): def __init__(self) -> None: super().__init__() self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: return self.triplet_loss(output[0], output[1], output[2])
[docs] class L1LossRepresentation(Criterion): def __init__(self) -> None: super().__init__() self.loss = torch.nn.L1Loss() def _variance(self, features: torch.Tensor) -> torch.Tensor: return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
[docs] class FocalLoss(Criterion): def __init__( self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean", ): super().__init__() raw_alpha = torch.tensor(alpha, dtype=torch.float32) self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha) self.gamma = gamma self.reduction = reduction
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: target = F.interpolate(targets[0], output.shape[2:], mode="nearest").long() logpt = F.log_softmax(output, dim=1) pt = torch.exp(logpt) logpt = logpt.gather(1, target) pt = pt.gather(1, target) at = self.alpha.to(target.device)[target].unsqueeze(1) loss = -at * ((1 - pt) ** self.gamma) * logpt if self.reduction == "mean": return loss.mean() elif self.reduction == "sum": return loss.sum() return loss
[docs] class FID(Criterion):
[docs] class InceptionV3(torch.nn.Module): def __init__(self) -> None: super().__init__() from torchvision.models import Inception_V3_Weights, inception_v3 self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False) self.model.fc = torch.nn.Identity() self.model.eval()
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x)
def __init__(self) -> None: super().__init__() self.inception_model = FID.InceptionV3().cuda()
[docs] @staticmethod def preprocess_images(image: torch.Tensor) -> torch.Tensor: return F.normalize( F.resize(image, (299, 299)).repeat((1, 3, 1, 1)), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ).cuda()
[docs] @staticmethod def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray: with torch.no_grad(): features = model(images).cpu().numpy() return features
[docs] @staticmethod def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float: mu1 = np.mean(real_features, axis=0) sigma1 = np.cov(real_features, rowvar=False) mu2 = np.mean(generated_features, axis=0) sigma2 = np.cov(generated_features, rowvar=False) diff = mu1 - mu2 from scipy import linalg covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if np.iscomplexobj(covmean): covmean = covmean.real return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3])) generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3])) real_features = FID.get_features(real_images, self.inception_model) generated_features = FID.get_features(generated_images, self.inception_model) return FID.calculate_fid(real_features, generated_features)
[docs] class MutualInformationLoss(torch.nn.Module): def __init__( self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7, ) -> None: super().__init__() bin_centers = torch.linspace(0.0, 1.0, num_bins) sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio self.num_bins = num_bins self.preterm = 1 / (2 * sigma**2) self.bin_centers = bin_centers[None, None, ...] self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr)
[docs] def parzen_windowing( self, pred: torch.Tensor, target: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: pred_weight, pred_probability = self.parzen_windowing_gaussian(pred) target_weight, target_probability = self.parzen_windowing_gaussian(target) return pred_weight, pred_probability, target_weight, target_probability
[docs] def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: img = torch.clamp(img, 0, 1) img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1) weight = torch.exp( -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2 ) # (batch, num_sample, num_bin) weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin) probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin) return weight, probability
[docs] def forward(self, pred: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: wa, pa, wb, pb = self.parzen_windowing(pred, targets[0]) # (batch, num_sample, num_bin), (batch, 1, num_bin) pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins) papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins) mi = torch.sum( pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2), ) # (batch) return torch.mean(mi).neg() # average over the batch and channel ndims
[docs] class CrossEntropyLoss(Criterion): def __init__(self, weight: list[float] | None = None, reduction: str = "mean") -> None: super().__init__() self.loss = torch.nn.CrossEntropyLoss(weight=torch.tensor(weight) if weight else None, reduction=reduction)
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: return self.loss(output, targets[0].squeeze(1))
[docs] class IMPACTReg(CriterionWithAttribute):
[docs] class Weights: def __init__(self, weights: list[float] = [0, 1]) -> None: self.weights = weights
def __init__( self, name: str = "Reg", model_name: str = "TS/M291.pt", shape: list[int] = [0, 0], in_channels: int = 3, loss: str = "torch:nn:L1Loss", weights: list[float] = [0, 1], ) -> None: super().__init__() if model_name is None: return self.name = name self.in_channels = in_channels self.nb_layer = len(weights) module, name = get_module(loss, "konfai.metric.measure") self.loss = apply_config(os.environ["KONFAI_CONFIG_PATH"])(getattr(module, name))() self.weights = weights self.model_path = hf_hub_download( repo_id="VBoussot/impact-torchscript-models", filename=model_name, repo_type="model", revision=None ) # nosec B615 self.model: torch.nn.Module = torch.jit.load(self.model_path, map_location=torch.device("cpu")) # nosec B614 self.dim = len(shape) self.shape = shape if all(s > 0 for s in shape) else None self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {} dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim))).to(0) try: out = self.model.to(0)(dummy_input, torch.tensor([self.nb_layer])) if not isinstance(out, (list, tuple)): raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.") if len(weights) != len(out): raise ValueError( f"Loss '{loss}': mismatch between the number of weights " f"({len(weights)}) and the number of model outputs " f"({len(out)}). Each output must have a corresponding weight." ) except Exception as e: msg = ( f"[Model Sanity Check Failed]\n" f"Input shape attempted: {dummy_input.shape}\n" f"Error: {type(e).__name__}: {e}" ) raise RuntimeError(msg) from e self.model = None
[docs] def preprocessing(self, tensor: torch.Tensor, attribute: list[Attribute]) -> list[torch.Tensor]: if tensor.shape[1] != self.in_channels: tensor = tensor.repeat(tuple([1, 3] + [1 for _ in range(self.dim)])) return [ tensor, torch.tensor([self.nb_layer]), torch.tensor( [ [ float(attr["ImageMin"]), float(attr["ImageMean"]), float(attr["ImageMax"]), float(attr["ImageStd"]), ] for attr in attribute ] ), ]
[docs] def get_name(self): return self.name
def _compute( self, output: torch.Tensor, output_attributes: list[Attribute], target: torch.Tensor, target_attributes: list[Attribute], mask: torch.Tensor | None, ) -> torch.Tensor: loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) output = self.preprocessing(output, output_attributes) target = self.preprocessing(target, target_attributes) true_nb = 0 if self.shape is not None: model_patch = ModelPatch(self.shape) model_patch.load(output[0].shape[2:]) for index in range(model_patch.get_size(0)): mask_patch = model_patch.get_data(mask, index, 0, True) if mask is not None else None if mask is None or (mask_patch is not None and torch.any(mask_patch == 1)): for i, zipped_output in enumerate( zip( self.model(model_patch.get_data(output[0], index, 0, True), output[1], output[2]), self.model(model_patch.get_data(target[0], index, 0, True), target[1], target[2]), ) ): if self.weights[i] == 0: continue output_feature = zipped_output[0] target_feature = zipped_output[1] if mask is not None: if mask_patch is None: raise RuntimeError("LPIPS mask patch is unexpectedly missing.") mask_patch_tensor = cast(torch.Tensor, mask_patch) mask_index_resampled = ( torch.nn.functional.interpolate( mask_patch_tensor.float(), mode="nearest", size=tuple(output_feature.shape[2:]) ).repeat((1, output_feature.shape[1], *([1] * self.dim))) == 1 ) if torch.any(mask_index_resampled): loss_value = self.weights[i] * self.loss( torch.masked_select(output_feature, mask_index_resampled).float(), torch.masked_select(target_feature, mask_index_resampled).float(), ) if loss_value.isnan(): continue loss += loss_value else: continue else: loss_value = self.weights[i] * self.loss(output_feature.float(), target_feature.float()) if loss_value.isnan(): continue loss += loss_value true_nb += 1 else: if mask is None or torch.any(mask == 1): for i, zipped_output in enumerate(zip(self.model(*output), self.model(*target))): if self.weights[i] == 0: continue output_feature = zipped_output[0] target_feature = zipped_output[1] if mask is not None: mask_index_resampled = ( torch.nn.functional.interpolate( mask.float(), mode="nearest", size=tuple(output_feature.shape[2:]) ).repeat((1, output_feature.shape[1], *([1] * self.dim))) == 1 ) if torch.any(mask_index_resampled): loss += self.weights[i] * self.loss( torch.masked_select(output_feature, mask_index_resampled).float(), torch.masked_select(target_feature, mask_index_resampled).float(), ) else: true_nb -= 1 else: loss += self.weights[i] * self.loss(output_feature.float(), target_feature.float()) true_nb += 1 return loss, true_nb
[docs] def forward( # type: ignore[override] self, output: torch.Tensor, *targets: torch.Tensor, attributes: list[list[Attribute]] ) -> tuple[torch.Tensor, float]: mask = targets[-1] if targets[-1].dtype == torch.uint8 else None if self.model is None: self.model = torch.jit.load(self.model_path) # nosec B614 self.model.to(output.device) self.model.eval() loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) if len(output.shape) == 5 and self.dim == 2: true_nb = 0 for i in range(output.shape[2]): loss_tmp, true_nb_tmp = self._compute( output[:, :, i, ...], attributes[0], targets[0][:, :, i, ...], attributes[1], mask[:, :, i, ...] if mask is not None else None, ) loss += loss_tmp true_nb += true_nb_tmp else: loss, true_nb = self._compute(output, attributes[0], targets[0], attributes[1], mask) return loss / true_nb, np.nan if true_nb == 0 else loss.item() / true_nb
[docs] class IMPACTSynth(CriterionWithAttribute):
[docs] class Weights: def __init__(self, weights: list[float] = [0, 1]) -> None: self.weights = weights
def _test_model(self, model_path_content: str, in_channels: int, shape: list[int], weights: list[float]): model: torch.nn.Module = torch.jit.load(model_path_content, map_location=torch.device("cpu")) # nosec B614 dummy_input = torch.zeros((1, in_channels, *shape)).to(0) try: out = model.to(0)(dummy_input, torch.tensor([len(weights)])) if not isinstance(out, (list, tuple)): raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.") if len(weights) != len(out): raise ValueError( f"Loss '{model_path_content}': mismatch between the number of weights " f"({len(weights)}) and the number of model outputs " f"({len(out)}). Each output must have a corresponding weight." ) except Exception as e: msg = ( f"[Model Sanity Check Failed]\n" f"Input shape attempted: {dummy_input.shape}\n" f"Error: {type(e).__name__}: {e}" ) raise RuntimeError(msg) from e def __init__( self, model_content_name: str, model_style_name: str, shape_content: list[int] = [0, 0], shape_style: list[int] = [0, 0], in_channels_content: int = 1, in_channels_style: int = 1, weights_criterion_content: list[float] = [0, 0, 1], weights_criterion_style: list[float] = [1, 1, 1], ) -> None: super().__init__() if model_content_name is None: return self.in_channels_content = in_channels_content self.in_channels_style = in_channels_style self.weights_criterion_content = weights_criterion_content self.weights_criterion_style = weights_criterion_style self.loss_content_function = torch.nn.MSELoss() self.loss_style_function = Gram() self.model_path_content = hf_hub_download( repo_id="VBoussot/impact-torchscript-models", filename=model_content_name, repo_type="model", revision=None ) # nosec B615 self.model_path_style = hf_hub_download( repo_id="VBoussot/impact-torchscript-models", filename=model_style_name, repo_type="model", revision=None ) # nosec B615 self.shape_content = shape_content if all(s > 0 for s in shape_content) else None self.shape_style = shape_style if all(s > 0 for s in shape_style) else None self._test_model( self.model_path_content, self.in_channels_content, self.shape_content if self.shape_content else [224] * len(shape_content), weights_criterion_content, ) self._test_model( self.model_path_style, self.in_channels_style, self.shape_style if self.shape_style else [224] * len(shape_style), weights_criterion_style, ) self.model_content: torch.nn.Module | None = None self.model_style: torch.nn.Module | None = None def _preprocessing( self, tensor: torch.Tensor, in_channels: int, nb_layer: int, attribute: list[Attribute] ) -> list[torch.Tensor]: if tensor.shape[1] != in_channels: tensor = tensor.repeat(tuple([1, in_channels] + [1 for _ in range(tensor.dim() - 2)])) if "Mean" in attribute[0] and "Std" in attribute[0]: mean_value = torch.tensor([float(a["Mean"]) for a in attribute], device=tensor.device).view( -1, *([1] * (tensor.dim() - 1)) ) std_value = torch.tensor([float(a["Std"]) for a in attribute], device=tensor.device).view( -1, *([1] * (tensor.dim() - 1)) ) tensor = tensor * std_value + mean_value elif "Min" in attribute[0] and "Max" in attribute[0]: min_value = torch.tensor([float(a["Min"]) for a in attribute], device=tensor.device).view( -1, *([1] * (tensor.dim() - 1)) ) max_value = torch.tensor([float(a["Max"]) for a in attribute], device=tensor.device).view( -1, *([1] * (tensor.dim() - 1)) ) tensor = (tensor + 1) / 2 * (max_value - min_value) + min_value return [ tensor, torch.tensor([nb_layer]), torch.tensor( [ [ float(attr["ImageMin"]), float(attr["ImageMean"]), float(attr["ImageMax"]), float(attr["ImageStd"]), ] for attr in attribute ] ), ] def _loss_compute( self, tensor: list[torch.Tensor], target: list[torch.Tensor], weights: list[float], shape: list[int] | None, mask: torch.Tensor | None, model: torch.nn.Module, loss_function: torch.nn.Module, ) -> tuple[torch.Tensor, int]: loss = torch.zeros((1), requires_grad=True).to(tensor[0].device, non_blocking=False).type(torch.float32) true_nb = 0 if shape is not None: model_patch = ModelPatch(shape) model_patch.load(tensor[0].shape[2:]) for index in range(model_patch.get_size(0)): mask_patch = model_patch.get_data(mask, index, 0, True) if mask is not None else None if mask is None or (mask_patch is not None and torch.any(mask_patch == 1)): for output_feature, target_feature, weight in zip( model(model_patch.get_data(tensor[0], index, 0, True), tensor[1], tensor[2]), model(model_patch.get_data(target[0], index, 0, True), target[1], target[2]), weights, ): if weight == 0: continue if mask is not None: if mask_patch is None: raise RuntimeError("IMPACTSynth mask patch is unexpectedly missing.") mask_patch_tensor = cast(torch.Tensor, mask_patch) mask_index_resampled = ( torch.nn.functional.interpolate( mask_patch_tensor.float(), mode="nearest", size=tuple(output_feature.shape[2:]) ).repeat((1, output_feature.shape[1], *([1] * (mask_patch_tensor.dim() - 2)))) == 1 ) if torch.any(mask_index_resampled): loss += weight * loss_function( torch.masked_select(output_feature, mask_index_resampled).float(), torch.masked_select(target_feature, mask_index_resampled).float(), ) else: true_nb -= 1 else: loss += weight * loss_function(output_feature.float(), target_feature.float()) true_nb += 1 else: if mask is None or torch.any(mask == 1): for output_feature, target_feature, weight in zip(model(*tensor), model(*target), weights): if weight == 0: continue if mask is not None: mask_index_resampled = ( torch.nn.functional.interpolate( mask.float(), mode="nearest", size=tuple(output_feature.shape[2:]) ).repeat((1, output_feature.shape[1], *([1] * (mask.dim() - 2)))) == 1 ) if torch.any(mask_index_resampled): loss += weight * loss_function( torch.masked_select(output_feature, mask_index_resampled).float(), torch.masked_select(target_feature, mask_index_resampled).float(), ) else: true_nb -= 1 else: loss += weight * loss_function(output_feature.float(), target_feature.float()) true_nb += 1 return loss, true_nb
[docs] def forward( # type: ignore[override] self, output: torch.Tensor, *targets: torch.Tensor, attributes: list[list[Attribute]] ) -> tuple[torch.Tensor, float]: if len(targets) < 2: raise ValueError("At least two target tensors are required.") if self.model_content is None: self.model_content = torch.jit.load(self.model_path_content, map_location=torch.device("cpu")) # nosec B614 self.model_content.eval() if self.model_style is None: self.model_style = torch.jit.load(self.model_path_style, map_location=torch.device("cpu")) # nosec B614 self.model_style.eval() model_content = self.model_content model_style = self.model_style if model_content is None or model_style is None: raise RuntimeError("IMPACTSynth models were not initialized correctly.") model_content.to(output.device) model_style.to(output.device) output_content = self._preprocessing( output, self.in_channels_content, len(self.weights_criterion_content), attributes[0] ) output_style = self._preprocessing( output, self.in_channels_style, len(self.weights_criterion_style), attributes[2] ) target_content = self._preprocessing( targets[0], self.in_channels_content, len(self.weights_criterion_content), attributes[1] ) target_style = self._preprocessing( targets[1], self.in_channels_style, len(self.weights_criterion_style), attributes[2] ) mask = targets[2] if len(targets) == 3 and targets[2].dtype == torch.uint8 else None loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) if len(output.shape) == 5 and len(self.weights_criterion_content) == 2: true_nb = 0 for i in range(output.shape[2]): loss_content, true_nb_content = self._loss_compute( self._preprocessing( output[:, :, i, ...], self.in_channels_content, len(self.weights_criterion_content), attributes[0], ), target_content, self.weights_criterion_content, self.shape_content, mask[:, :, i, ...] if mask is not None else None, model=model_content, loss_function=self.loss_content_function, ) loss += loss_content true_nb += true_nb_content else: loss_content, true_nb_content = self._loss_compute( output_content, target_content, self.weights_criterion_content, self.shape_content, mask if mask is not None else None, model=model_content, loss_function=self.loss_content_function, ) loss = loss_content true_nb = true_nb_content if len(output.shape) == 5 and len(self.weights_criterion_style) == 2: true_nb = 0 for i in range(output.shape[2]): loss_style, true_nb_style = self._loss_compute( self._preprocessing( output[:, :, i, ...], self.in_channels_style, len(self.weights_criterion_style), attributes[2] ), target_style, self.weights_criterion_style, self.shape_style, mask[:, :, i, ...] if mask is not None else None, model=model_style, loss_function=self.loss_style_function, ) loss += loss_style true_nb += true_nb_style else: loss_style, true_nb_style = self._loss_compute( output_style, target_style, self.weights_criterion_style, self.shape_style, mask if mask is not None else None, model=model_style, loss_function=self.loss_style_function, ) loss += loss_style true_nb += true_nb_style return loss / true_nb, np.nan if true_nb == 0 else loss.item() / true_nb
[docs] class SAM_Perceptual(CriterionWithAttribute): # noqa: N801 def __init__(self) -> None: super().__init__() self.model: torch.nn.Module | None = None self.loss = torch.nn.L1Loss() self.model_path = hf_hub_download( repo_id="VBoussot/ImpactSynth", filename="SAM2.1_Small.pt", repo_type="model", revision=None ) # nosec B615
[docs] def preprocessing(self, tensor: torch.Tensor, attribute: list[Attribute]) -> list[torch.Tensor]: tensor = tensor.repeat(1, 3, 1, 1) return [ tensor, torch.tensor([4]), torch.tensor( [ [ float(attr["ImageMin"]), float(attr["ImageMean"]), float(attr["ImageMax"]), float(attr["ImageStd"]), ] for attr in attribute ] ), ]
def _compute( self, output: torch.Tensor, target: torch.Tensor, target_attributes: list[Attribute], mask: torch.Tensor | None ) -> torch.Tensor: loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) model = self.model if model is None: raise RuntimeError("SAM perceptual model is not initialized.") output = self.preprocessing(output, target_attributes) target = self.preprocessing(target, target_attributes) true_nb = 0 model_patch = ModelPatch([512, 512]) model_patch.load(output[0].shape[2:]) for index in range(model_patch.get_size(0)): mask_patch = model_patch.get_data(mask, index, 0, True) if mask is not None else None if mask is None or (mask_patch is not None and torch.any(mask_patch == 1)): for zipped_output in zip( model(model_patch.get_data(output[0], index, 0, True), output[1], output[2]), model(model_patch.get_data(target[0], index, 0, True), target[1], target[2]), ): output_feature = zipped_output[0] target_feature = zipped_output[1] if mask_patch is not None: mask_patch_tensor = cast(torch.Tensor, mask_patch) mask_index_resampled = ( torch.nn.functional.interpolate( mask_patch_tensor.float(), mode="nearest", size=tuple(output_feature.shape[2:]) ).repeat((1, output_feature.shape[1], 1, 1)) == 1 ) if torch.any(mask_index_resampled): loss += self.loss( torch.masked_select(output_feature, mask_index_resampled).float(), torch.masked_select(target_feature, mask_index_resampled).float(), ) else: continue else: loss += self.loss(output_feature.float(), target_feature.float()) true_nb += 1 return loss, true_nb
[docs] def forward( # type: ignore[override] self, output: torch.Tensor, *targets: torch.Tensor, attributes: list[list[Attribute]] ) -> tuple[torch.Tensor, float]: mask = targets[-1] if targets[-1].dtype == torch.uint8 else None if self.model is None: self.model = torch.jit.load(self.model_path, map_location=torch.device("cpu")) # nosec B614 model = self.model if model is None: raise RuntimeError("SAM perceptual model failed to load.") model.eval() model.to(output.device) loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32) if len(output.shape) == 5: true_nb = 0 for i in range(output.shape[2]): loss_tmp, true_nb_tmp = self._compute( output[:, :, i, ...], targets[0][:, :, i, ...], attributes[1], mask[:, :, i, ...] if mask is not None else None, ) loss += loss_tmp true_nb += true_nb_tmp else: loss, true_nb = self._compute(output, targets[0], attributes[1], mask) return loss / true_nb, np.nan if true_nb == 0 else loss.item() / true_nb
[docs] class Variance(Criterion): def __init__(self, name: str = "Variance") -> None: super().__init__() self.name = name
[docs] def get_name(self): return self.name
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: return output.float().var(1).mean(), output.float().var(1).mean().item()
[docs] class Mean(Criterion): def __init__(self, name: str = "Mean") -> None: super().__init__() self.name = name
[docs] def get_name(self): return self.name
[docs] def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor: loss = output.float().mean() return loss, loss.item()