# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Configuration helpers that map YAML trees to KonfAI Python objects."""
import collections
import inspect
import os
import types
import typing
from collections.abc import Sequence
from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, Union, get_args, get_origin
import ruamel.yaml
import torch
from konfai import config_file
from konfai.utils.errors import ConfigError
# Module-wide YAML (de)serializer: every ``Config`` context uses it to load
# the config file on entry and to dump the merged result back on exit.
yaml = ruamel.yaml.YAML()
[docs]
class Config:
"""
Context manager for reading and updating a subtree of the active YAML
config.
Parameters
----------
key : str
Dot-separated path pointing to the configuration subtree to inspect or
materialize.
"""
def __init__(self, key: str) -> None:
self.filename = Path(os.environ["KONFAI_config_file"])
self.keys = key.split(".")
def __enter__(self):
if not self.filename.exists():
mode = os.environ.get("KONFAI_CONFIG_MODE", "Done")
if mode in {"default", "interactive", "Import"}:
self.filename.parent.mkdir(parents=True, exist_ok=True)
self.filename.touch()
else:
raise ConfigError(
f"Config file '{self.filename}' does not exist.",
("Create it first or enable a creation mode with " "KONFAI_CONFIG_MODE=default."),
)
self.yml = open(self.filename, encoding="utf-8")
self.data = yaml.load(self.yml)
if self.data is None:
self.data = {}
self.config = self.data
for key in self.keys:
if self.config is None or key not in self.config:
self.config = {key: {}}
self.config = self.config[key]
return self
[docs]
def create_dictionary(self, data, keys, i) -> dict:
if keys[i] not in data:
data = {keys[i]: data}
if i == 0:
return data
else:
i -= 1
return self.create_dictionary(data, keys, i)
[docs]
def merge(self, dict1, dict2) -> dict:
result = deepcopy(dict1)
for key, value in dict2.items():
if isinstance(value, collections.abc.Mapping):
result[key] = self.merge(result.get(key, {}), value)
else:
if dict2[key] is not None:
result[key] = deepcopy(dict2[key])
return result
def __exit__(self, exc_type, value, traceback) -> None:
self.yml.close()
if os.environ["KONFAI_CONFIG_MODE"] == "remove":
if os.path.exists(config_file()):
os.remove(config_file())
return
with open(self.filename) as yml:
data = yaml.load(yml)
if data is None:
data = {}
with open(self.filename, "w") as yml:
# Only the currently visited subtree is rewritten; the recursive
# merge preserves the rest of the YAML file untouched.
yaml.dump(
self.merge(
data,
self.create_dictionary(
self.config,
self.keys,
len(self.keys) - 1,
),
),
yml,
)
@staticmethod
def _get_input(name: str, default: str) -> str:
try:
options = ",".join(default.split(":")[1:]) if ":" in default else ""
return input(f"{name} [{options}]: ")
except (EOFError, KeyboardInterrupt):
# Interactive editing is optional; when stdin is unavailable we
# degrade to default materialization instead of aborting the run.
os.environ["KONFAI_CONFIG_MODE"] = "default"
return default.split("|")[1] if len(default.split("|")) > 1 else default
@staticmethod
def _get_input_default(
name: str,
default: str | None,
is_list: bool = False,
) -> list[str | None] | str | None:
# ``default|value`` is KonfAI's marker for "materialize this default if
# the user/config did not provide a concrete value".
if isinstance(default, str) and (
default == "default" or (len(default.split("|")) > 1 and default.split("|")[0] == "default")
):
if os.environ["KONFAI_CONFIG_MODE"] == "interactive":
if is_list:
list_tmp: list[str | None] = []
key_tmp = "OK"
while key_tmp != "!" and key_tmp != " " and os.environ["KONFAI_CONFIG_MODE"] == "interactive":
key_tmp = Config._get_input(name, default)
if key_tmp != "!" and key_tmp != " ":
if key_tmp == "":
key_tmp = default.split("|")[1] if len(default.split("|")) > 1 else default
list_tmp.append(key_tmp)
return list_tmp
else:
value = Config._get_input(name, default)
if value == "":
return default.split("|")[1] if len(default.split("|")) > 1 else default
else:
return value
else:
default = default.split("|")[1] if len(default.split("|")) > 1 else default
return [default] if is_list else default
[docs]
def get_value(self, name, default) -> object:
if not isinstance(self.config, collections.abc.MutableMapping):
return None
if name in self.config and self.config[name] is not None:
value = self.config[name]
if value is None:
value = default
value_config = value
else:
value = Config._get_input_default(
name,
default if default != inspect._empty else None,
)
value_config = value
if isinstance(value_config, tuple):
value_config = list(value)
if isinstance(value_config, list):
list_tmp = []
for key in value_config:
res = Config._get_input_default(name, key, is_list=True)
if isinstance(res, list):
list_tmp.extend(res)
else:
list_tmp.append(str(res))
value = list_tmp
value_config = list_tmp
if isinstance(value, dict):
key_tmp = []
value_config = {}
dict_value = {}
for key in value:
res = Config._get_input_default(name, key, is_list=True)
if isinstance(res, list):
key_tmp.extend(res)
else:
key_tmp.append(str(res))
for key in key_tmp:
if key in value:
value_tmp = value[key]
else:
value_tmp = next(v for k, v in value.items() if "default" in k)
value_config[key] = None
dict_value[key] = value_tmp
value = dict_value
self.config[name] = value_config if value_config is not None else "None"
if value == "None":
value = None
return value
[docs]
def config(key: str | None = None):
"""
Attach a KonfAI configuration key to a class or callable.
Parameters
----------
key : str | None, optional
Configuration branch handled by the decorated object.
Returns
-------
Callable
Decorator storing the key on the decorated object.
"""
def decorator(function):
function._key = key if key is not None else function.__name__
return function
return decorator
_CONFIG_PRIMITIVE_TYPES = {
int,
str,
bool,
float,
torch.Tensor,
}
_CONFIG_SUPPORTED_TYPES_MESSAGE = (
"Config: The config only supports types : config(Object), int, str, "
"bool, float, list[int], list[str], list[bool], list[float], "
"dict[str, Object]"
)
def _resolve_annotation(function, annotation):
if annotation in {"int", "float", "bool", "str"}:
return {"int": int, "float": float, "bool": bool, "str": str}[annotation]
if not isinstance(annotation, str):
return annotation
try:
return eval( # nosec B307
annotation,
{
**getattr(function, "__globals__", {}),
"Any": Any,
"Literal": Literal,
"Sequence": Sequence,
"Union": Union,
"bool": bool,
"dict": dict,
"float": float,
"int": int,
"list": list,
"str": str,
"torch": torch,
"tuple": tuple,
"typing": typing,
},
)
except Exception:
return annotation
def _unwrap_optional(annotation):
origin = get_origin(annotation)
if origin not in {Union, types.UnionType}:
return annotation
args = [arg for arg in get_args(annotation) if arg not in {type(None), types.NoneType}]
if len(args) == 1:
return args[0]
if len(args) > 1:
return args[0]
return annotation
def _convert_union_sequence_value(
value: object,
valid_types: tuple[type | object, ...],
param_name: str,
) -> object:
converted = None
last_error: Exception | None = None
for candidate_type in valid_types:
try:
if candidate_type is Any:
return value
if candidate_type in {type(None), types.NoneType}:
if value in (None, "None"):
return None
continue
if not isinstance(candidate_type, type):
continue
current_value = (
torch.tensor(value) if candidate_type == torch.Tensor and not isinstance(value, torch.Tensor) else value
)
converted = current_value if candidate_type == torch.Tensor else candidate_type(current_value)
break
except Exception as exc:
last_error = exc
if converted is None and value not in (None, "None"):
raise ConfigError(
f"Invalid value '{value}' for parameter '{param_name}'.",
f"Expected one of: {valid_types}.",
f"Last conversion error: {last_error}" if last_error else "",
)
return converted
[docs]
def apply_config(konfai_args: str | None = None):
"""
Recursively instantiate callables from the active KonfAI configuration.
Parameters
----------
konfai_args : str | None, optional
Root configuration path used to resolve nested constructor arguments.
Returns
-------
Callable
Decorator that injects configuration-backed arguments at call time.
"""
def decorator(function):
def new_function(*args, **kwargs):
key = getattr(function, "_key", None)
key_tmp = konfai_args + ("." + key if key is not None else "") if konfai_args is not None else key
if (
"KONFAI_config_file" in os.environ
and "KONFAI_CONFIG_MODE" in os.environ
and os.environ["KONFAI_CONFIG_MODE"] != "Import"
and key_tmp is not None
):
previous_path = os.environ.get("KONFAI_CONFIG_PATH")
os.environ["KONFAI_CONFIG_PATH"] = key_tmp
without = kwargs["konfai_without"] if "konfai_without" in kwargs else []
try:
with Config(key_tmp) as config:
if not isinstance(config.config, collections.abc.Mapping):
return None
kwargs = {}
params = list(inspect.signature(function).parameters.values())
for param in params[len(args) :]:
if param.name in without:
continue
annotation = _resolve_annotation(function, param.annotation)
if get_origin(annotation) is Literal:
allowed_values = get_args(annotation)
default_value = param.default if param.default != inspect._empty else allowed_values[0]
value = config.get_value(
param.name,
f"default|{default_value}",
)
if value not in allowed_values:
raise ConfigError(
f"Invalid value '{value}' for "
f"parameter '{param.name}' expected "
f"one of: {allowed_values}."
)
kwargs[param.name] = value
continue
annotation = _unwrap_optional(annotation)
if annotation == inspect._empty:
if param.name != "self":
kwargs[param.name] = config.get_value(
param.name,
param.default,
)
continue
if annotation in _CONFIG_PRIMITIVE_TYPES or annotation is Any:
kwargs[param.name] = config.get_value(
param.name,
param.default,
)
continue
origin = get_origin(annotation)
if origin in {list, tuple, Sequence, collections.abc.Sequence}:
values = config.get_value(
param.name,
param.default,
)
if values is None:
kwargs[param.name] = None
continue
args_annotation = get_args(annotation)
elem_type = args_annotation[0] if args_annotation else Any
elem_origin = get_origin(elem_type)
if elem_origin in {Union, types.UnionType}:
valid_types = get_args(elem_type)
kwargs[param.name] = [
_convert_union_sequence_value(value, valid_types, param.name)
for value in values
]
elif elem_type in {int, str, bool, float, torch.Tensor, Any}:
kwargs[param.name] = values
else:
raise ConfigError(_CONFIG_SUPPORTED_TYPES_MESSAGE)
continue
if origin is dict:
key_type, value_type = get_args(annotation)
if key_type is not str:
raise ConfigError(_CONFIG_SUPPORTED_TYPES_MESSAGE)
values = config.get_value(
param.name,
param.default,
)
if values is None or value_type in {
int,
str,
bool,
float,
Any,
}:
kwargs[param.name] = values
continue
try:
kwargs[param.name] = {
value: apply_config(f"{key_tmp}.{param.name}.{value}")(value_type)()
for value in values
}
except Exception as exc:
raise ConfigError(f"{values} {exc}") from exc
continue
try:
kwargs[param.name] = apply_config(key_tmp)(annotation)()
except Exception as exc:
raise ConfigError(
"Failed to instantiate " f"{param.name} with type " f"{annotation}, error {exc}"
) from exc
return function(*args, **kwargs)
finally:
if previous_path is None:
os.environ.pop("KONFAI_CONFIG_PATH", None)
else:
os.environ["KONFAI_CONFIG_PATH"] = previous_path
return function(*args, **kwargs)
return new_function
return decorator