diff --git a/docs/hirundo.cli_dataset_qa.rst b/docs/hirundo.cli_dataset_qa.rst
new file mode 100644
index 00000000..2144fdf0
--- /dev/null
+++ b/docs/hirundo.cli_dataset_qa.rst
@@ -0,0 +1,10 @@
+.. meta::
+   :http-equiv=Content-Security-Policy: default-src 'self', frame-ancestors 'none'
+
+hirundo.cli_dataset_qa module
+=============================
+
+.. automodule:: hirundo.cli_dataset_qa
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/hirundo.cli_eval.rst b/docs/hirundo.cli_eval.rst
new file mode 100644
index 00000000..85c429f3
--- /dev/null
+++ b/docs/hirundo.cli_eval.rst
@@ -0,0 +1,10 @@
+.. meta::
+   :http-equiv=Content-Security-Policy: default-src 'self', frame-ancestors 'none'
+
+hirundo.cli_eval module
+=======================
+
+.. automodule:: hirundo.cli_eval
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/hirundo.cli_unlearning.rst b/docs/hirundo.cli_unlearning.rst
new file mode 100644
index 00000000..1b54735e
--- /dev/null
+++ b/docs/hirundo.cli_unlearning.rst
@@ -0,0 +1,10 @@
+.. meta::
+   :http-equiv=Content-Security-Policy: default-src 'self', frame-ancestors 'none'
+
+hirundo.cli_unlearning module
+=============================
+
+.. automodule:: hirundo.cli_unlearning
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/hirundo/_cli_common.py b/hirundo/_cli_common.py
new file mode 100644
index 00000000..c2557fb7
--- /dev/null
+++ b/hirundo/_cli_common.py
@@ -0,0 +1,106 @@
+import re
+import sys
+from collections.abc import Callable
+from enum import Enum
+from typing import Any
+
+import typer
+from rich.console import Console
+from rich.markup import escape
+from rich.table import Table
+
+_RUN_ID_RE = re.compile(r"^[a-zA-Z0-9_-]+$")
+
+docs = "sphinx" in sys.modules
+hirundo_epilog = (
+    None
+    if docs
+    else "Made with ❤️ by Hirundo. Visit https://www.hirundo.io for more information."
+)
+
+console = Console()
+
+
+def make_app(name: str, help_text: str) -> typer.Typer:
+    """
+    Create a Typer sub-application that uses the shared Hirundo CLI settings.
+    """
+    return typer.Typer(
+        name=name,
+        no_args_is_help=True,
+        rich_markup_mode="rich",
+        epilog=hirundo_epilog,
+        help=help_text,
+    )
+
+
+def validate_run_id(run_id: str) -> str:
+    """
+    Return the run ID unchanged if it is well-formed, otherwise print an error and exit.
+
+    Raises:
+        typer.Exit: With exit code 1 if the run ID is empty or contains anything
+            other than alphanumeric characters, hyphens, and underscores.
+    """
+    if not _RUN_ID_RE.fullmatch(run_id):
+        # The rejected value is raw user input: escape it so that square brackets
+        # are printed literally instead of being parsed as Rich markup.
+        console.print(
+            f"[red]Invalid run ID '{escape(run_id)}'. "
+            "Run IDs may only contain alphanumeric characters, hyphens, and underscores.[/red]"
+        )
+        raise typer.Exit(code=1)
+    return run_id
+
+
+def validate_enum(value: str, enum_cls: type[Enum], label: str) -> Any:
+    """
+    Upper-case an option value and convert it into a member of the given enum.
+
+    Raises:
+        typer.Exit: With exit code 1, after printing the valid options, if the
+            upper-cased value is not a valid value of the enum.
+    """
+    try:
+        return enum_cls(value.upper())
+    except ValueError:
+        valid = ", ".join(str(member.value) for member in enum_cls)
+        # Escape the user-supplied value so it cannot be parsed as Rich markup.
+        console.print(
+            f"[red]Invalid {label} '{escape(value)}'. Valid options: {valid}[/red]"
+        )
+        raise typer.Exit(code=1) from None
+
+
+def wait_or_notify(
+    run_id: str, check_fn: Callable[[str], Any], cmd_name: str, wait: bool
+) -> Any:
+    """
+    Wait for a launched run, or tell the user how to check on it later.
+
+    Returns the result of ``check_fn(run_id)`` when ``wait`` is true. Otherwise
+    prints the matching ``check`` command, including the run ID, and returns None.
+    """
+    if wait:
+        return check_fn(run_id)
+    console.print(
+        f"Use [bold]hirundo {cmd_name} check[/bold] "
+        f"[italic]{escape(str(run_id))}[/italic] to monitor progress."
+    )
+    return None
+
+
+def print_runs_table(
+    title: str,
+    columns: tuple[str, ...],
+    rows: list[tuple[str | None, ...]],
+) -> None:
+    """
+    Print an expanded table with one column per header and one row per run.
+    """
+    table = Table(title=title, expand=True)
+    for col in columns:
+        table.add_column(col, overflow="fold")
+    for row in rows:
+        table.add_row(*row)
+    console.print(table)
diff --git a/hirundo/cli.py b/hirundo/cli.py
index 6714e596..0744df49 100644
--- a/hirundo/cli.py
+++ b/hirundo/cli.py
@@ -1,23 +1,16 @@
 import os
 import re
-import sys
 from pathlib import Path
 from typing import Annotated
 from urllib.parse import urlparse
 
 import typer
-from rich.console import Console
-from rich.table import Table
 
+from hirundo._cli_common import docs, hirundo_epilog, validate_run_id
 from hirundo._env import API_HOST, EnvLocation
-
-docs = "sphinx" in sys.modules
-hirundo_epilog = (
-    None
-    if docs
-    else "Made with ❤️ by Hirundo. Visit https://www.hirundo.io for more information."
-)
-
+from hirundo.cli_dataset_qa import dataset_qa_app, dataset_qa_list
+from hirundo.cli_eval import eval_app
+from hirundo.cli_unlearning import unlearning_app
 
 app = typer.Typer(
     name="hirundo",
@@ -26,6 +19,10 @@
     epilog=hirundo_epilog,
 )
 
+app.add_typer(eval_app, name="eval")
+app.add_typer(dataset_qa_app, name="dataset-qa")
+app.add_typer(unlearning_app, name="unlearning")
+
 
 def _upsert_env(dotenv_filepath: str | Path, var_name: str, var_value: str):
     """
@@ -197,8 +194,9 @@ def check_run(
     """
     from hirundo.dataset_qa import QADataset
 
-    results = QADataset.check_run_by_id(run_id)
-    print(f"Run results saved to {results.cached_zip_path}")
+    results = QADataset.check_run_by_id(validate_run_id(run_id))
+    if results is not None:
+        print(f"Run results saved to {results.cached_zip_path}")
 
 
 @app.command("list-runs", epilog=hirundo_epilog)
@@ -206,36 +204,7 @@ def list_runs():
     """
     List all runs available.
     """
-    from hirundo.dataset_qa import QADataset
-
-    runs = QADataset.list_runs()
-
-    console = Console()
-    table = Table(
-        title="Runs:",
-        expand=True,
-    )
-    cols = (
-        "Dataset name",
-        "Run ID",
-        "Status",
-        "Created At",
-        "Run Args",
-    )
-    for col in cols:
-        table.add_column(
-            col,
-            overflow="fold",
-        )
-    for run in runs:
-        table.add_row(
-            str(run.name),
-            str(run.id),
-            str(run.status),
-            run.created_at.isoformat(),
-            run.run_args.model_dump_json() if run.run_args else None,
-        )
-    console.print(table)
+    dataset_qa_list(archived=False)
 
 
 typer_click_object = typer.main.get_command(app)
diff --git a/hirundo/cli_dataset_qa.py b/hirundo/cli_dataset_qa.py
new file mode 100644
index 00000000..50c7b2e9
--- /dev/null
+++ b/hirundo/cli_dataset_qa.py
@@ -0,0 +1,80 @@
+from typing import Annotated
+
+import typer
+
+from hirundo._cli_common import (
+    console,
+    hirundo_epilog,
+    make_app,
+    print_runs_table,
+    validate_run_id,
+    wait_or_notify,
+)
+
+dataset_qa_app = make_app("dataset-qa", "Launch and monitor Dataset QA runs.")
+
+
+@dataset_qa_app.command("run", epilog=hirundo_epilog)
+def dataset_qa_run(
+    dataset_id: Annotated[int, typer.Argument(help="ID of the dataset to run QA on.")],
+    wait: Annotated[
+        bool,
+        typer.Option(
+            "--wait/--no-wait", help="Wait for the run to complete and stream progress."
+        ),
+    ] = True,
+):
+    """
+    Launch a Dataset QA run on the dataset with the given ID.
+    """
+    from hirundo.dataset_qa import QADataset
+
+    run_id = QADataset.launch_qa_run(dataset_id)
+    console.print(f"Dataset QA run started. Run ID: [bold]{run_id}[/bold]")
+
+    results = wait_or_notify(run_id, QADataset.check_run_by_id, "dataset-qa", wait)
+    if results is not None:
+        console.print(f"Run results saved to {results.cached_zip_path}")
+
+
+@dataset_qa_app.command("list", epilog=hirundo_epilog)
+def dataset_qa_list(
+    archived: Annotated[
+        bool,
+        typer.Option("--archived/--no-archived", help="Include archived runs."),
+    ] = False,
+):
+    """
+    List Dataset QA runs.
+    """
+    from hirundo.dataset_qa import QADataset
+
+    runs = QADataset.list_runs(archived=archived)
+    print_runs_table(
+        "Dataset QA Runs:",
+        ("Dataset Name", "Run ID", "Status", "Created At", "Run Args"),
+        [
+            (
+                str(run.name),
+                str(run.run_id),
+                str(run.status),
+                run.created_at.isoformat(),
+                run.run_args.model_dump_json() if run.run_args else None,
+            )
+            for run in runs
+        ],
+    )
+
+
+@dataset_qa_app.command("check", epilog=hirundo_epilog)
+def dataset_qa_check(
+    run_id: Annotated[str, typer.Argument(help="The run ID to check.")],
+):
+    """
+    Check the status of a Dataset QA run and stream progress.
+    """
+    from hirundo.dataset_qa import QADataset
+
+    results = QADataset.check_run_by_id(validate_run_id(run_id))
+    if results is not None:
+        console.print(f"Run results saved to {results.cached_zip_path}")
diff --git a/hirundo/cli_eval.py b/hirundo/cli_eval.py
new file mode 100644
index 00000000..6ddc7863
--- /dev/null
+++ b/hirundo/cli_eval.py
@@ -0,0 +1,129 @@
+from typing import Annotated
+
+import typer
+
+from hirundo._cli_common import (
+    console,
+    hirundo_epilog,
+    make_app,
+    print_runs_table,
+    validate_enum,
+    validate_run_id,
+    wait_or_notify,
+)
+
+eval_app = make_app("eval", "Launch and monitor LLM behavior evaluation runs.")
+
+
+@eval_app.command("run", epilog=hirundo_epilog)
+def eval_run(
+    preset: Annotated[
+        str,
+        typer.Option(
+            "--preset",
+            help="Evaluation preset. One of: BBQ_BIAS, BBQ_UNBIAS, UNQOVER_BIAS, HALU_EVAL, MED_HALLU, INJECTION_EVAL",
+        ),
+    ],
+    model_id: Annotated[
+        int | None,
+        typer.Option("--model-id", help="ID of the LLM model to evaluate."),
+    ] = None,
+    source_run_id: Annotated[
+        str | None,
+        typer.Option("--source-run-id", help="ID of the unlearning run to evaluate."),
+    ] = None,
+    name: Annotated[
+        str | None,
+        typer.Option("--name", help="Optional name for this evaluation run."),
+    ] = None,
+    wait: Annotated[
+        bool,
+        typer.Option(
+            "--wait/--no-wait", help="Wait for the run to complete and stream progress."
+        ),
+    ] = True,
+):
+    """
+    Launch an LLM behavior evaluation run.
+
+    Either --model-id or --source-run-id must be provided.
+    """
+    from hirundo.llm_behavior_eval import (
+        EvalRunInfo,
+        LlmBehaviorEval,
+        ModelOrRun,
+        PresetType,
+    )
+
+    if model_id is None and source_run_id is None:
+        console.print(
+            "[red]Error: either --model-id or --source-run-id must be provided.[/red]"
+        )
+        raise typer.Exit(code=1)
+    if model_id is not None and source_run_id is not None:
+        console.print(
+            "[red]Error: only one of --model-id or --source-run-id may be provided.[/red]"
+        )
+        raise typer.Exit(code=1)
+
+    if source_run_id is not None:
+        source_run_id = validate_run_id(source_run_id)
+
+    preset_type = validate_enum(preset, PresetType, "preset")
+    model_or_run = ModelOrRun.MODEL if model_id is not None else ModelOrRun.RUN
+    run_info = EvalRunInfo(
+        model_id=model_id,
+        source_run_id=source_run_id,
+        preset_type=preset_type,
+        name=name,
+    )
+
+    run_id = LlmBehaviorEval.launch_eval_run(model_or_run, run_info)
+    console.print(f"Eval run started. Run ID: [bold]{run_id}[/bold]")
+
+    results = wait_or_notify(run_id, LlmBehaviorEval.check_run_by_id, "eval", wait)
+    if results is not None:
+        console.print(f"Run results saved to {results.cached_zip_path}")
+
+
+@eval_app.command("list", epilog=hirundo_epilog)
+def eval_list(
+    archived: Annotated[
+        bool,
+        typer.Option("--archived/--no-archived", help="Include archived runs."),
+    ] = False,
+):
+    """
+    List LLM behavior evaluation runs.
+    """
+    from hirundo.llm_behavior_eval import LlmBehaviorEval
+
+    runs = LlmBehaviorEval.list_runs(archived=archived)
+    print_runs_table(
+        "Eval Runs:",
+        ("Run ID", "Name", "Status", "Preset", "Created At"),
+        [
+            (
+                str(run.run_id),
+                str(run.name),
+                str(run.status),
+                run.preset_type.value if run.preset_type else None,
+                run.created_at.isoformat(),
+            )
+            for run in runs
+        ],
+    )
+
+
+@eval_app.command("check", epilog=hirundo_epilog)
+def eval_check(
+    run_id: Annotated[str, typer.Argument(help="The run ID to check.")],
+):
+    """
+    Check the status of an LLM behavior evaluation run and stream progress.
+    """
+    from hirundo.llm_behavior_eval import LlmBehaviorEval
+
+    results = LlmBehaviorEval.check_run_by_id(validate_run_id(run_id))
+    if results is not None:
+        console.print(f"Run results saved to {results.cached_zip_path}")
diff --git a/hirundo/cli_unlearning.py b/hirundo/cli_unlearning.py
new file mode 100644
index 00000000..c7df1495
--- /dev/null
+++ b/hirundo/cli_unlearning.py
@@ -0,0 +1,134 @@
+from typing import Annotated
+
+import typer
+
+from hirundo._cli_common import (
+    console,
+    hirundo_epilog,
+    make_app,
+    print_runs_table,
+    validate_enum,
+    validate_run_id,
+    wait_or_notify,
+)
+
+unlearning_app = make_app("unlearning", "Launch and monitor LLM unlearning runs.")
+
+
+@unlearning_app.command("run", epilog=hirundo_epilog)
+def unlearning_run(
+    model_id: Annotated[int, typer.Argument(help="ID of the LLM model to unlearn.")],
+    bias_type: Annotated[
+        str | None,
+        typer.Option(
+            "--bias-type",
+            help="Bias type for unlearning. One of: ALL, RACE, NATIONALITY, GENDER, PHYSICAL_APPEARANCE, RELIGION, AGE",
+        ),
+    ] = None,
+    hallucination_type: Annotated[
+        str | None,
+        typer.Option(
+            "--hallucination-type",
+            help="Hallucination type for unlearning. One of: GENERAL, MEDICAL, LEGAL, DEFENSE",
+        ),
+    ] = None,
+    name: Annotated[
+        str | None,
+        typer.Option("--name", help="Optional name for this unlearning run."),
+    ] = None,
+    wait: Annotated[
+        bool,
+        typer.Option(
+            "--wait/--no-wait", help="Wait for the run to complete and stream progress."
+        ),
+    ] = True,
+):
+    """
+    Launch an LLM unlearning run.
+
+    Exactly one of --bias-type or --hallucination-type must be provided.
+    """
+    from hirundo.llm_bias_type import BBQBiasType
+    from hirundo.unlearning_llm import (
+        BiasBehavior,
+        DefaultUtility,
+        HallucinationBehavior,
+        HallucinationType,
+        LlmRunInfo,
+        LlmUnlearningRun,
+    )
+
+    if bias_type is None and hallucination_type is None:
+        console.print(
+            "[red]Error: either --bias-type or --hallucination-type must be provided.[/red]"
+        )
+        raise typer.Exit(code=1)
+    if bias_type is not None and hallucination_type is not None:
+        console.print(
+            "[red]Error: only one of --bias-type or --hallucination-type may be provided.[/red]"
+        )
+        raise typer.Exit(code=1)
+
+    if bias_type is not None:
+        target_behavior = BiasBehavior(
+            bias_type=validate_enum(bias_type, BBQBiasType, "bias type")
+        )
+    elif hallucination_type is not None:
+        target_behavior = HallucinationBehavior(
+            hallucination_type=validate_enum(
+                hallucination_type, HallucinationType, "hallucination type"
+            )
+        )
+    else:
+        raise typer.Exit(code=1) from None
+
+    run_info = LlmRunInfo(
+        name=name,
+        target_behaviors=[target_behavior],
+        target_utilities=[DefaultUtility()],
+    )
+
+    run_id = LlmUnlearningRun.launch(model_id, run_info)
+    console.print(f"Unlearning run started. Run ID: [bold]{run_id}[/bold]")
+
+    wait_or_notify(run_id, LlmUnlearningRun.check_run_by_id, "unlearning", wait)
+
+
+@unlearning_app.command("list", epilog=hirundo_epilog)
+def unlearning_list(
+    archived: Annotated[
+        bool,
+        typer.Option("--archived/--no-archived", help="Include archived runs."),
+    ] = False,
+):
+    """
+    List LLM unlearning runs.
+    """
+    from hirundo.unlearning_llm import LlmUnlearningRun
+
+    runs = LlmUnlearningRun.list(archived=archived)
+    print_runs_table(
+        "Unlearning Runs:",
+        ("Name", "Run ID", "Status", "Created At"),
+        [
+            (
+                str(run.name),
+                str(run.run_id),
+                str(run.status),
+                run.created_at.isoformat(),
+            )
+            for run in runs
+        ],
+    )
+
+
+@unlearning_app.command("check", epilog=hirundo_epilog)
+def unlearning_check(
+    run_id: Annotated[str, typer.Argument(help="The run ID to check.")],
+):
+    """
+    Check the status of an LLM unlearning run and stream progress.
+    """
+    from hirundo.unlearning_llm import LlmUnlearningRun
+
+    LlmUnlearningRun.check_run_by_id(validate_run_id(run_id))
diff --git a/tests/test_cli_common.py b/tests/test_cli_common.py
new file mode 100644
index 00000000..3b18d59d
--- /dev/null
+++ b/tests/test_cli_common.py
@@ -0,0 +1,56 @@
+from unittest.mock import MagicMock, patch
+
+import hirundo._cli_common as cli_common  # noqa: E402
+import pytest
+import typer
+from hirundo._cli_common import validate_run_id, wait_or_notify
+
+
+class TestValidateRunId:
+    def test_valid_id_returned_unchanged(self):
+        assert validate_run_id("abc-123_XYZ") == "abc-123_XYZ"
+
+    @pytest.mark.parametrize(
+        "bad_id", ["run/id", "run\\id", "run id", "run\nid", "run.id", ""]
+    )
+    def test_invalid_id_exits(self, bad_id):
+        with pytest.raises(typer.Exit) as exc:
+            validate_run_id(bad_id)
+        assert exc.value.exit_code == 1
+
+    def test_invalid_id_prints_message(self, bad_id="bad id"):
+        with (
+            patch.object(cli_common.console, "print") as mock_print,
+            pytest.raises(typer.Exit),
+        ):
+            validate_run_id(bad_id)
+        output = mock_print.call_args[0][0]
+        assert "bad id" in output
+        assert "may only contain" in output
+
+    def test_invalid_id_is_escaped_for_rich_markup(self):
+        with (
+            patch.object(cli_common.console, "print") as mock_print,
+            pytest.raises(typer.Exit),
+        ):
+            validate_run_id("[/bold]")
+        assert "\\[/bold]" in mock_print.call_args[0][0]
+
+
+class TestWaitOrNotify:
+    def test_wait_true_calls_check_fn_and_returns_result(self):
+        check_fn = MagicMock(return_value="result")
+        assert wait_or_notify("run-1", check_fn, "dataset-qa", wait=True) == "result"
+        check_fn.assert_called_once_with("run-1")
+
+    def test_wait_false_returns_none_without_calling_check_fn(self):
+        check_fn = MagicMock()
+        assert wait_or_notify("run-1", check_fn, "dataset-qa", wait=False) is None
+        check_fn.assert_not_called()
+
+    def test_wait_false_prints_check_hint(self):
+        with patch.object(cli_common.console, "print") as mock_print:
+            wait_or_notify("run-1", MagicMock(), "dataset-qa", wait=False)
+        output = mock_print.call_args[0][0]
+        assert "dataset-qa check" in output
+        assert "run-1" in output