Source code for konfai.main

#!/usr/bin/env python3
#
# Copyright (c) 2025 Valentin Boussot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Command-line entrypoints for KonfAI workflows, apps, and services."""

import argparse
import importlib.metadata
import os
import sys

from konfai import cuda_visible_devices
from konfai.utils.runtime import State

# Make the directory KonfAI is launched from importable, placing it first on
# the module search path (only once, if it is not already present).
# NOTE(review): presumably this lets user-defined modules that live next to the
# config files be imported by name — confirm with the config loader.
_cwd = os.getcwd()
if not any(path_entry == _cwd for path_entry in sys.path):
    sys.path.insert(0, _cwd)


def _add_common_args(parser: argparse.ArgumentParser, devices) -> None:
    """
    Register the arguments shared by TRAIN / RESUME / PREDICTION / EVALUATION.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The subcommand parser to extend.
    devices
        Allowed values for ``--gpu`` (used as argparse ``choices``); ``_run``
        passes the result of ``cuda_visible_devices()``.
    """
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=None,
        help="Path to the configuration file (YAML). "
        "If omitted, a command-specific default is used (e.g. Train.yml, Prediction.yml, Evaluation.yml).",
    )
    parser.add_argument(
        "-y",
        "--overwrite",
        action="store_true",
        help="Overwrite existing outputs (checkpoints, logs, predictions) without prompting.",
    )

    def positive_int(value: str) -> int:
        # argparse ``type`` callable: only strictly positive integers are accepted.
        ivalue = int(value)
        if ivalue <= 0:
            raise argparse.ArgumentTypeError("CPU value must be > 0")
        return ivalue

    # GPU and CPU selection cannot be combined on the same command line.
    device_group = parser.add_mutually_exclusive_group()
    # NOTE(review): the two help texts below disagree about what happens when
    # neither flag is given; confirm the downstream default and align them.
    device_group.add_argument(
        "--gpu",
        type=int,
        nargs="+",  # space-separated ids, e.g. `--gpu 0 1 2`
        choices=devices,
        default=[],
        help="GPU device ids to use, e.g. '0' or '0 1 2'. If omitted runs on CPU.",
    )
    device_group.add_argument(
        "--cpu",
        type=positive_int,
        default=None,
        help="Run on CPU using N worker processes/cores. If omitted, uses GPU when available.",
    )

    parser.add_argument("-q", "--quiet", action="store_true", help="Suppress console output for a quieter execution")
    parser.add_argument("-tb", "--tensorboard", action="store_true", help="Launch TensorBoard.")


def _add_training_output_args(parser: argparse.ArgumentParser, legacy_single_dash: bool = False) -> None:
    """
    Register ``--checkpoints-dir`` and ``--statistics-dir`` on *parser*.

    Both options also accept the underscore spelling (``--checkpoints_dir``).
    When *legacy_single_dash* is true, the historical single-dash spellings
    (``-checkpoints-dir``, ``-checkpoints_dir``, ...) are kept as aliases so
    existing RESUME command lines keep working. The destination names are
    ``checkpoints_dir`` and ``statistics_dir`` in every case.
    """
    outputs = (
        ("checkpoints", "./Checkpoints/", "checkpoints are saved"),
        ("statistics", "./Statistics/", "training statistics/logs are saved"),
    )
    for name, default, what in outputs:
        # The first `--` option determines the argparse dest (e.g. `checkpoints_dir`).
        flags = [f"--{name}-dir", f"--{name}_dir"]
        if legacy_single_dash:
            flags += [f"-{name}-dir", f"-{name}_dir"]
        parser.add_argument(
            *flags,
            type=str,
            default=default,
            help=f"Directory where {what} (default: {default}).",
        )


def _run(parser: argparse.ArgumentParser) -> None:
    """
    Shared CLI builder and dispatcher for the main KonfAI training/inference commands.

    This function:
    1) defines common arguments used by TRAIN / RESUME / PREDICTION / EVALUATION
       (config file, overwrite, device selection, quiet, tensorboard)
    2) defines subcommands and their command-specific arguments
    3) parses CLI args (``sys.argv``) and dispatches to the correct implementation:
       - `konfai.trainer.train` for TRAIN and RESUME
       - `konfai.predictor.predict` for PREDICTION
       - `konfai.evaluator.evaluate` for EVALUATION

    Device selection
    ----------------
    GPU and CPU are mutually exclusive:
    - `--gpu` accepts one or more space-separated GPU ids, constrained to the
      devices returned by `cuda_visible_devices()`.
    - `--cpu` accepts a strictly positive integer (number of workers).

    Output directories
    ------------------
    TRAIN and RESUME both accept `--checkpoints-dir` / `--statistics-dir`
    (RESUME additionally keeps the legacy single-dash spellings).

    Config handling
    ---------------
    The `config` key is always removed from the argument dict. If `--config`
    was given, its value is forwarded under the keyword each implementation
    expects (`prediction_file`, `evaluations_file` or `config`); if omitted,
    nothing is forwarded so downstream functions can use their own default
    config filename.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The top-level parser created by the caller. Any extra top-level
        arguments already registered on it are forwarded to the implementation too.
    """
    # Query the visible CUDA devices once rather than once per subcommand.
    devices = cuda_visible_devices()

    subparsers = parser.add_subparsers(dest="command", required=True)

    train_p = subparsers.add_parser(str(State.TRAIN), help="Train a model from scratch.")
    _add_common_args(train_p, devices)
    _add_training_output_args(train_p)

    resume_p = subparsers.add_parser(str(State.RESUME), help="Resume training from existing checkpoints.")
    _add_common_args(resume_p, devices)
    resume_p.add_argument(
        "--model",
        type=str,
        required=True,
        help="Checkpoint path to resume from",
    )
    _add_training_output_args(resume_p, legacy_single_dash=True)

    predict_p = subparsers.add_parser(str(State.PREDICTION), help="Run inference using a trained model.")
    _add_common_args(predict_p, devices)
    predict_p.add_argument(
        "--models",
        type=str,
        nargs="+",
        metavar="PATH",
        required=True,
        help="One or more checkpoint/model paths to run inference with.",
    )
    predict_p.add_argument(
        "--predictions-dir",
        "--predictions_dir",
        type=str,
        default="./Predictions/",
        help="Directory where predictions are written (default: ./Predictions/).",
    )

    eval_p = subparsers.add_parser(str(State.EVALUATION), help="Evaluate model.")
    _add_common_args(eval_p, devices)
    eval_p.add_argument(
        "--evaluations-dir",
        "--evaluations_dir",
        type=str,
        default="./Evaluations/",
        help="Directory where evaluation outputs are written (default: ./Evaluations/).",
    )

    parser.add_argument(
        "--version",
        action="version",
        version=importlib.metadata.version("konfai"),
        help="Print KonfAI version and exit.",
    )

    _dispatch(vars(parser.parse_args()))


def _dispatch(args: dict) -> None:
    """
    Call the implementation matching ``args["command"]`` with ``**args``.

    ``config`` is always popped from *args*. When it is not ``None``, it is
    re-inserted as ``prediction_file`` (PREDICTION), ``evaluations_file``
    (EVALUATION) or ``config`` (TRAIN / RESUME). Implementations are
    imported lazily so only the selected one is loaded.
    """
    command = args["command"]
    config = args.pop("config", None)

    # Compare against the same `str(State.X)` values used to register the subcommands.
    if command == str(State.PREDICTION):
        from konfai.predictor import predict

        if config is not None:
            args["prediction_file"] = config
        predict(**args)
    elif command == str(State.EVALUATION):
        from konfai.evaluator import evaluate

        if config is not None:
            args["evaluations_file"] = config
        evaluate(**args)
    else:
        # TRAIN and RESUME share the trainer entry point.
        from konfai.trainer import train

        if config is not None:
            args["config"] = config
        train(**args)


def main():
    """
    Run the ``konfai`` command-line interface.

    Builds the top-level argument parser (abbreviated long options disabled)
    and hands it to ``_run``. ``_run`` registers the TRAIN, RESUME, PREDICTION
    and EVALUATION subcommands, parses the command line and dispatches to the
    matching implementation:

    - ``konfai.trainer.train`` for TRAIN / RESUME
    - ``konfai.predictor.predict`` for PREDICTION
    - ``konfai.evaluator.evaluate`` for EVALUATION
    """
    cli = argparse.ArgumentParser(
        prog="konfAI",
        description="KonfAI – Deep learning framework for Medical AI Models",
        allow_abbrev=False,
    )
    _run(cli)


def _add_cluster_args(parser: argparse.ArgumentParser) -> None:
    """
    Register the "Cluster manager arguments" group on *parser*.

    Adds ``--name`` (required), ``--num-nodes``, ``--memory``,
    ``--time-limit`` and ``--resubmit``. Underscore spellings
    (``--num_nodes``, ``--time_limit``) are accepted as aliases.
    """
    cluster_args = parser.add_argument_group("Cluster manager arguments")
    cluster_args.add_argument("--name", type=str, help="Task name", required=True)
    cluster_args.add_argument("--num-nodes", "--num_nodes", default=1, type=int, help="Number of nodes")
    # NOTE(review): the unit of --memory is not stated here (presumably GB) — confirm with the cluster backend.
    cluster_args.add_argument("--memory", type=int, default=16, help="Amount of memory per node")
    cluster_args.add_argument(
        "--time-limit",
        "--time_limit",
        type=int,
        default=1440,
        help="Job time limit in minutes",
    )
    cluster_args.add_argument(
        "--resubmit",
        action="store_true",
        help="Automatically resubmit job just before timeout",
    )


def cluster():
    """
    Entry point for running KonfAI with cluster-oriented CLI arguments.

    Extends the standard KonfAI CLI with a "Cluster manager arguments" group
    (job name, nodes, memory, time limit, resubmit), then delegates parsing
    and command dispatching to ``_run(parser)``.

    Notes
    -----
    - The cluster options belong to the top-level parser, so on the command
      line they must appear before the subcommand (e.g. ``--name job TRAIN ...``).
    - This function only defines extra CLI arguments before delegating to ``_run``.
    """
    parser = argparse.ArgumentParser(
        prog="konfAI",
        description="KonfAI – Deep learning framework for Medical AI Models",
        allow_abbrev=False,
    )
    _add_cluster_args(parser)
    _run(parser)