# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Schedulers used to modulate metric and loss weights during training."""
from abc import abstractmethod
from functools import partial
import numpy as np
import torch
[docs]
class Scheduler:
    """Base class for scalar schedulers used by KonfAI criteria.

    An instance keeps two pieces of state: ``baseValue``, the starting value
    converted to ``float``, and ``it``, the iteration index most recently
    passed to :meth:`step`. Subclasses turn that state into a scalar in
    :meth:`get_value`.

    Note: the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker on :meth:`get_value` is informational only.
    """

    def __init__(self, start_value: float) -> None:
        """Store the starting value and reset the iteration index to zero."""
        self.it = 0
        self.baseValue = float(start_value)

    def step(self, it: int):
        """Record ``it`` as the current iteration index."""
        self.it = it

    @abstractmethod
    def get_value(self) -> float:
        """Return the scheduled value for the current iteration.

        The base implementation has an empty body (it returns ``None``);
        subclasses are expected to override it.
        """
[docs]
class Constant(Scheduler):
    """Scheduler whose value stays the same regardless of the iteration."""

    def __init__(self, value: float = 1):
        """Forward ``value`` to the base class as the starting value."""
        super().__init__(start_value=value)

    def get_value(self) -> float:
        """Return the value stored at construction time (``baseValue``)."""
        constant = self.baseValue
        return constant
[docs]
class CosineAnnealing(Scheduler):
    """Cosine annealing scheduler for criterion weights.

    The value follows
    ``eta_min + (start_value - eta_min) * (1 + cos(pi * it / t_max)) / 2``.
    It equals ``start_value`` at iteration 0 and ``eta_min`` at iteration
    ``t_max``. The cosine is periodic, so the value rises again for
    iterations beyond ``t_max``.

    Args:
        start_value: Value returned at iteration 0.
        eta_min: Value reached at iteration ``t_max``.
        t_max: Number of iterations for half a cosine period. A value of 0
            makes :meth:`get_value` raise ``ZeroDivisionError``.
    """

    def __init__(self, start_value: float = 1, eta_min: float = 0.00001, t_max: int = 100):
        super().__init__(start_value)
        self.eta_min = eta_min
        self.t_max = t_max

    def get_value(self) -> float:
        """Return the annealed value for the current iteration as a builtin ``float``."""
        amplitude = self.baseValue - self.eta_min
        cosine = np.cos(self.it * torch.pi / self.t_max)
        # np.cos yields a NumPy scalar; convert it so the result matches the
        # ``float`` return type declared by the other schedulers.
        return float(self.eta_min + amplitude * (1 + cosine) / 2)
[docs]
class Warmup(torch.optim.lr_scheduler.LambdaLR):
    """Learning-rate warmup wrapper compatible with PyTorch optimizers.

    The multiplicative factor applied to each group's base learning rate is
    ``min(1, (step + 1) / (warmup_steps + 1))``: it grows linearly and is
    capped at 1 once ``step`` reaches ``warmup_steps``.
    """

    @staticmethod
    def warmup(warmup_steps: int, step: int) -> float:
        """Return the warmup factor for ``step`` given ``warmup_steps``."""
        return min(1.0, (step + 1) / (warmup_steps + 1))

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int = 10,
        last_epoch=-1,
        verbose="deprecated",
    ):
        """Build the scheduler.

        Args:
            optimizer: Optimizer whose learning rates are scaled.
            warmup_steps: Step index at which the factor first reaches 1.
            last_epoch: Forwarded to ``LambdaLR``.
            verbose: Forwarded to ``LambdaLR`` only when explicitly set to
                something other than the ``"deprecated"`` sentinel.
        """
        lr_lambda = partial(Warmup.warmup, warmup_steps)
        if verbose == "deprecated":
            # Do not forward the sentinel: PyTorch versions without a
            # ``verbose`` parameter would reject it, and versions that expect
            # a bool would read the non-empty string as "enabled".
            super().__init__(optimizer, lr_lambda, last_epoch)
        else:
            super().__init__(optimizer, lr_lambda, last_epoch, verbose=verbose)
[docs]
class PolyLRScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer,
initial_lr: float,
max_steps: int,
exponent: float = 0.9,
current_step: int | None = None,
):
self.initial_lr = initial_lr
self.max_steps = max_steps
self.exponent = exponent
self.ctr = 0 if current_step is None else current_step
for param_group in optimizer.param_groups:
param_group["lr"] = initial_lr
param_group.setdefault("initial_lr", initial_lr)
super().__init__(optimizer, last_epoch=-1)
[docs]
def step(self, current_step=None):
if current_step is None:
current_step = self.ctr
self.ctr += 1
current_step = min(current_step, self.max_steps)
new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
for param_group in self.optimizer.param_groups:
param_group["lr"] = new_lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
[docs]
def get_last_lr(self):
return self._last_lr