Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ Activate the local virtualenv before running any Python/uv commands: `source .ve

- Frameworks: `pytest` and `pytest-asyncio`.
- Place tests in `tests/`; name files `test_*.py`.
- Write tests in `pytest` style; do not add new `unittest`-based tests or `unittest`
assertions/fixtures.
- Run locally with `pytest` before opening a PR (CI runs lint + integration tests).

## Pull Request Guidelines
Expand Down
13 changes: 13 additions & 0 deletions hirundo/_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ class EnvLocation(enum.Enum):
API_KEY = os.getenv("API_KEY")


def get_env_bool(variable_name: str, default: bool = False) -> bool:
    """Read an environment variable and interpret it as a boolean flag.

    Matching ignores case and surrounding whitespace. ``1``, ``true``,
    ``yes`` and ``on`` are truthy; ``0``, ``false``, ``no`` and ``off`` are
    falsy.

    Args:
        variable_name: Name of the environment variable to read.
        default: Value returned when the variable is unset or holds text
            that is not one of the recognised spellings.

    Returns:
        The parsed flag, or ``default`` when it cannot be determined.
    """
    raw_value = os.getenv(variable_name)
    if raw_value is None:
        return default

    normalized_value = raw_value.strip().lower()
    if normalized_value in {"1", "true", "yes", "on"}:
        return True
    if normalized_value in {"0", "false", "no", "off"}:
        return False

    # Unrecognised text (including an empty string) is treated like "unset".
    return default
Comment thread
benglewis marked this conversation as resolved.


def check_api_key():
if not API_KEY:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions hirundo/_llm_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class HuggingFaceTransformersModelOutput(BaseModel):
ModelSourceType.HUGGINGFACE_TRANSFORMERS
)
model_name: str
token: str | None = None


class LocalTransformersModel(BaseModel):
Expand Down
149 changes: 149 additions & 0 deletions hirundo/_model_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from pathlib import Path

from huggingface_hub import HfApi
from huggingface_hub.errors import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
)
from requests import HTTPError

from hirundo._hirundo_error import HirundoError
from hirundo.logger import get_logger

logger = get_logger(__name__)


def _build_huggingface_access_message(
    model_name: str,
    model_role: str,
    hint: str,
    token_provided: bool,
) -> str:
    """Compose a user-facing explanation for a failed Hugging Face access check.

    Args:
        model_name: Repository ID quoted in the message.
        model_role: Role label placed before the word "model" (for example
            ``"judge"`` or ``"LLM"``).
        hint: Failure category. ``"gated"``, ``"not_found"`` and
            ``"unauthorized"`` have dedicated wording; any other value
            produces the generic message.
        token_provided: Whether the caller supplied a token; selects between
            the "with token" and "without token" wording.

    Returns:
        The complete message string.
    """
    # hint -> (wording when a token was supplied, wording when it was not)
    endings_by_hint = {
        "gated": (
            " is gated and the provided HuggingFace token does not "
            "have access. Please request access or use a different token.",
            " is gated. Please provide a HuggingFace token with access.",
        ),
        "not_found": (
            " was not found or is private/gated for the provided "
            "token. Please verify the model ID or token access.",
            " was not found or is private/gated. Please provide a "
            "HuggingFace token or verify the model ID.",
        ),
        "unauthorized": (
            " could not be accessed with the provided HuggingFace "
            "token. Please verify token permissions or use a different model.",
            " could not be accessed without a HuggingFace token. "
            "Please provide a token or use a public model.",
        ),
    }

    message_prefix = f"The {model_role} model '{model_name}'"
    endings = endings_by_hint.get(hint)
    if endings is None:
        # Unknown category: fall back to wording that fits either token state.
        return (
            f"{message_prefix} could not be accessed. Please verify the model ID or provide "
            "a HuggingFace token with access."
        )

    with_token, without_token = endings
    return message_prefix + (with_token if token_provided else without_token)


def _is_local_model_path(path_or_repo_id: str) -> bool:
    """Return whether the identifier names an existing filesystem entry.

    Args:
        path_or_repo_id: Candidate local path (``~`` is expanded) or remote
            repository ID.

    Returns:
        True when the expanded path exists; False otherwise, including for
        blank input and for strings that cannot be resolved as a path.
    """
    # Path("") normalises to the current directory, which always exists; a
    # blank identifier must not be mistaken for a local model.
    if not path_or_repo_id.strip():
        return False

    try:
        return Path(path_or_repo_id).expanduser().exists()
    except (OSError, RuntimeError):
        # expanduser() raises RuntimeError for an unknown "~user" prefix and
        # exists() may raise OSError (e.g. over-long names) on some Python
        # versions. Either way the string is not a usable local path.
        return False


def _get_huggingface_error_status_code(
    exception: HfHubHTTPError | HTTPError,
) -> int | None:
    """Extract the HTTP status code carried by an HTTP error, if any.

    Args:
        exception: The caught HTTP error.

    Returns:
        ``exception.response.status_code`` when a response object is
        attached; None when the response is missing or None.
    """
    # getattr keeps this helper safe for error objects raised without a
    # response attached.
    response = getattr(exception, "response", None)
    if response is None:
        return None
    return response.status_code


def validate_huggingface_model_access(
    model_name: str,
    token: str | None,
    model_role: str,
) -> None:
    """Validate that a Hugging Face model can be accessed.

    Calls ``HfApi.model_info`` for the repository and converts the Hub
    errors handled below into a ``HirundoError`` with an actionable message.

    Args:
        model_name: Hugging Face repository ID for the model to validate.
        token: Optional Hugging Face access token used for authenticated access.
        model_role: Human-readable role for the model in error messages.

    Raises:
        HirundoError: If the Hub reports the repository as gated or not
            found, or the request fails with another HTTP error.
    """
    token_provided = token is not None

    def _access_error(hint: str) -> HirundoError:
        # Single place that turns a failure category into the raised error.
        return HirundoError(
            _build_huggingface_access_message(
                model_name=model_name,
                model_role=model_role,
                hint=hint,
                token_provided=token_provided,
            )
        )

    huggingface_api = HfApi(token=token)
    # Handler order matters: the more specific Hub errors must be matched
    # before the general HTTP error handler.
    try:
        huggingface_api.model_info(repo_id=model_name)
    except GatedRepoError as exception:
        raise _access_error("gated") from exception
    except RepositoryNotFoundError as exception:
        raise _access_error("not_found") from exception
    except (HfHubHTTPError, HTTPError) as exception:
        status_code = _get_huggingface_error_status_code(exception)
        hint = "unauthorized" if status_code in {401, 403} else "generic"
        logger.debug(
            "HuggingFace access validation failed for %s model '%s' with status %s.",
            model_role,
            model_name,
            status_code,
        )
        raise _access_error(hint) from exception
Comment thread
benglewis marked this conversation as resolved.


def validate_judge_model_access(path_or_repo_id: str, token: str | None) -> None:
    """Validate that a judge model can be accessed.

    A value that resolves to an existing local path is accepted as-is;
    anything else is checked as a Hugging Face repository ID.

    Args:
        path_or_repo_id: Local filesystem path or Hugging Face repository ID for the
            judge model.
        token: Optional Hugging Face access token used when the judge model is hosted
            on Hugging Face.

    Raises:
        HirundoError: Propagated from ``validate_huggingface_model_access``
            when the repository cannot be accessed.
    """
    if _is_local_model_path(path_or_repo_id):
        # Local checkpoint: nothing to verify against the Hub.
        return

    validate_huggingface_model_access(
        model_name=path_or_repo_id,
        token=token,
        model_role="judge",
    )
36 changes: 35 additions & 1 deletion hirundo/llm_behavior_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from hirundo._http import raise_for_status_with_reason, requests
from hirundo._iter_sse_retrying import aiter_sse_retrying, iter_sse_retrying
from hirundo._llm_sources import HuggingFaceTransformersModelOutput, LlmSourcesOutput
from hirundo._model_access import (
validate_huggingface_model_access,
validate_judge_model_access,
)
from hirundo._run_checking import (
DEFAULT_MAX_RETRIES,
STATUS_TO_PROGRESS_MAP,
Expand All @@ -29,6 +33,7 @@
from hirundo.llm_behavior_eval_results import LlmBehaviorEvalResults
from hirundo.llm_bias_type import BBQBiasType, UnqoverBiasType
from hirundo.logger import get_logger
from hirundo.unlearning_llm import LlmModel
from hirundo.unzip import download_and_extract_llm_behavior_eval_zip

logger = get_logger(__name__)
Expand Down Expand Up @@ -64,6 +69,8 @@ class JudgeModel(BaseModel):


class EvalRunInfo(BaseModel):
model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

organization_id: int | None = None
name: str | None = None
model_id: int | None = None
Expand All @@ -75,7 +82,9 @@ class EvalRunInfo(BaseModel):


class OutputLlm(BaseModel):
model_config = {"extra": "allow"}
model_config = ConfigDict(
extra="allow", protected_namespaces=("model_validate", "model_dump")
)

id: int
organization_id: int
Expand Down Expand Up @@ -116,6 +125,8 @@ class LlmEvalMetrics(BaseModel):


class EvalRunRecord(BaseModel):
model_config = ConfigDict(protected_namespaces=("model_validate", "model_dump"))

id: int
name: str
model_id: int | None
Expand Down Expand Up @@ -143,6 +154,27 @@ class LlmBehaviorEval:
def __init__(self, run_id: str | None = None):
self.run_id = run_id

@staticmethod
def _validate_model_access(model_or_run: ModelOrRun, run_info: EvalRunInfo) -> None:
    """Check model accessibility before an evaluation run is submitted.

    The judge model (when configured) is validated first. For
    ``ModelOrRun.MODEL`` launches the target model is then looked up via
    ``LlmModel.get_by_id`` and, if its source is a Hugging Face
    Transformers model, validated as well.

    Args:
        model_or_run: Whether the evaluation targets a model or a run.
        run_info: Launch parameters; ``judge_model`` and ``model_id`` are read.

    Raises:
        HirundoLlmBehaviorEvalError: If ``model_or_run`` is
            ``ModelOrRun.MODEL`` but ``run_info.model_id`` is None.
        HirundoError: Propagated from the access validators when a model
            cannot be accessed.
    """
    judge_model = run_info.judge_model
    if judge_model is not None:
        validate_judge_model_access(
            path_or_repo_id=judge_model.path_or_repo_id,
            token=judge_model.token,
        )

    if model_or_run != ModelOrRun.MODEL:
        # Only model launches need the target-model checks below.
        return

    if run_info.model_id is None:
        raise HirundoLlmBehaviorEvalError(
            "model_id is required when model_or_run is 'model'"
        )

    llm_model = LlmModel.get_by_id(run_info.model_id)
    model_source = llm_model.model_source
    if isinstance(model_source, HuggingFaceTransformersModelOutput):
        validate_huggingface_model_access(
            model_name=model_source.model_name,
            token=model_source.token,
            model_role="LLM",
        )
Comment thread
cursor[bot] marked this conversation as resolved.

@staticmethod
def _parse_eval_run_record(response_payload: dict) -> EvalRunRecord:
model_payload = response_payload.get("model")
Expand Down Expand Up @@ -219,6 +251,8 @@ def launch_eval_run(
else:
model_or_run_value = model_or_run

LlmBehaviorEval._validate_model_access(model_or_run_value, run_info)

response = requests.post(
f"{API_HOST}/llm-behavior-eval/run/{model_or_run_value.value}",
json=run_info.model_dump(mode="json"),
Expand Down
7 changes: 5 additions & 2 deletions hirundo/llm_behavior_eval_results.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import typing
from pathlib import Path

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict

T = typing.TypeVar("T")


class LlmBehaviorEvalResults(BaseModel, typing.Generic[T]):
model_config = {"arbitrary_types_allowed": True}
model_config = ConfigDict(
arbitrary_types_allowed=True,
protected_namespaces=("model_validate", "model_dump"),
)

cached_zip_path: Path
"""
Expand Down
23 changes: 21 additions & 2 deletions hirundo/unlearning_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from hirundo._env import API_HOST
from hirundo._env import API_HOST, get_env_bool
from hirundo._headers import get_headers
from hirundo._http import raise_for_status_with_reason, requests
from hirundo._llm_pipeline import get_hf_pipeline_for_run_given_model
from hirundo._llm_sources import LlmSources, LlmSourcesOutput
from hirundo._llm_sources import (
HuggingFaceTransformersModel,
LlmSources,
LlmSourcesOutput,
)
from hirundo._model_access import validate_huggingface_model_access
from hirundo._run_checking import (
STATUS_TO_PROGRESS_MAP,
aiter_run_events,
Expand Down Expand Up @@ -48,7 +53,21 @@ class LlmModel(BaseModel):
def create(
self,
replace_if_exists: bool = False,
validate_hf_access: bool = True,
) -> int:
should_validate_hf_access = validate_hf_access or get_env_bool(
"HIRUNDO_VALIDATE_HF_ACCESS",
default=True,
)

if should_validate_hf_access and isinstance(
self.model_source, HuggingFaceTransformersModel
):
validate_huggingface_model_access(
model_name=self.model_source.model_name,
token=self.model_source.token,
model_role="LLM",
)
llm_model_response = requests.post(
f"{API_HOST}/unlearning-llm/llm/",
json={
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"stamina>=24.2.0",
"httpx-sse>=0.4.0",
"tqdm>=4.66.5",
"huggingface-hub>=1.0.0",
"h11>=0.16.0",
# ⬆️ Required to fix vulnerability GHSA-vqfr-h8mv-ghfj
"requests>=2.32.4",
Expand Down Expand Up @@ -66,8 +67,8 @@ dev = [
"basedpyright==1.37.1",
"virtualenv>=20.36.1",
# ⬆️ Needed for `pre-commit` version fix for vulnerability GHSA-rqc4-2hc7-8c8v
"authlib>=1.6.6",
# ⬆️ Required to fix vulnerability CVE-2025-68158
"authlib>=1.6.7",
# ⬆️ Required to fix vulnerability CVE-2026-28802
"ruff>=0.12.0",
"bumpver>=2025.1131",
"platformdirs>=4.3.6",
Expand Down
8 changes: 4 additions & 4 deletions tests/dataset_qa_shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from collections import defaultdict
from contextlib import contextmanager

Expand All @@ -9,6 +8,7 @@
RunArgs,
StorageConfig,
)
from hirundo._env import get_env_bool
from hirundo._run_status import RunStatus
from hirundo.logger import get_logger

Expand Down Expand Up @@ -168,8 +168,8 @@ def dataset_qa_sync_test(
run_args: RunArgs | None = None,
):
logger.info("Sync: Finished cleanup")
if (os.getenv("FULL_TEST", "false") == "true" and sanity) or (
alternative_env and os.getenv(alternative_env, "false") == "true"
if (get_env_bool("FULL_TEST") and sanity) or (
alternative_env and get_env_bool(alternative_env)
):
run_id = test_dataset.run_qa(replace_dataset_if_exists=True, run_args=run_args)
logger.info("Sync: Started dataset QA run with run ID %s", run_id)
Expand All @@ -189,7 +189,7 @@ async def dataset_qa_async_test(
run_args: RunArgs | None = None,
):
logger.info("Async: Finished cleanup")
if os.getenv(env, "false") == "true":
if get_env_bool(env):
run_id = test_dataset.run_qa(replace_dataset_if_exists=True, run_args=run_args)
logger.info("Async: Started dataset QA run with run ID %s", run_id)
events_generator = test_dataset.acheck_run()
Expand Down
Loading
Loading