# Source code for ooai_llm.litellm_registry

"""Raw LiteLLM registry exploration helpers.

Purpose:
    Expose LiteLLM's full local pricing/context registry for exploratory CLI
    and TUI workflows without forcing every provider into ooai's canonical
    provider enum.

Design:
    - Keep this separate from ``model_defaults`` because factory defaults and
      runtime profiles intentionally operate on first-class supported
      providers.
    - Preserve arbitrary LiteLLM provider labels such as ``fireworks_ai``,
      ``bedrock``, ``openrouter``, or ``together_ai``.
    - Reuse the same filter vocabulary as the supported model catalog wherever
      the raw metadata is available.
"""

from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime, timezone
from decimal import Decimal
import importlib
import re
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, computed_field

from .model_defaults import ModelCapabilityName, ModelCatalogSortName

# Registry ``mode`` values that ``_is_chat_like`` accepts as chat/generation
# rows; a row whose non-null mode is outside this set is treated as non-chat.
_CHAT_MODES = {"chat", "completion", "responses", "messages"}
# Substrings that mark a model key as non-chat (embeddings, audio, images,
# moderation, ...). ``_is_chat_like`` lowercases the key and rejects it when
# any of these appears anywhere in it.
_EXCLUDED_NAME_PARTS = (
    "audio",
    "babbage",
    "dall-e",
    "dalle",
    "davinci",
    "edit",
    "embedding",
    "embed",
    "image",
    "moderation",
    "realtime",
    "rerank",
    "sora",
    "speech",
    "transcribe",
    "translation",
    "tts",
    "whisper",
)
# Name hints for small/cheap models; consumed by ``_cheap_name_score`` to
# derive the "cheap" capability label.
_SMALL_MODEL_WORDS = ("nano", "mini", "haiku", "flash", "lite", "small", "fast", "8b", "3b")
# Premium-tier name hints; ``supports_reasoning`` also counts these as
# reasoning-oriented.
_EXPENSIVE_SPECIAL_WORDS = ("pro", "opus", "max", "ultra")
# Name hints for reasoning-oriented models, checked with ``_name_contains``.
# NOTE(review): ``_name_contains`` is a plain substring test, so short hints
# such as "o1", "pro" or "max" can also fire inside unrelated names.
_REASONING_WORDS = ("reasoning", "think", "thinking", "magistral", "opus", "pro", "reasoner", "o1", "o3", "o4")
# Name hints for coding-oriented models (``supports_coding``).
_CODING_WORDS = ("code", "codex", "codestral", "devstral", "coder")
# Name hints for vision-capable models, used alongside the registry's own
# ``supports_vision`` flag when building labels and capability filters.
_VISION_WORDS = ("vision", "visual", "pixtral", "vl", "multimodal")


class LiteLLMRegistryModel(BaseModel):
    """One raw LiteLLM registry row with normalized display metadata.

    Scalar fields hold values normalized from the registry entry, while ``raw``
    keeps the untouched entry so computed properties can fall back to keys that
    are not promoted to first-class fields.
    """

    model_config = ConfigDict(extra="forbid")

    provider: str
    model_key: str
    model_id: str
    source: str = "litellm_registry"
    display_name: str | None = None
    created: int | None = None
    created_at: str | None = None
    input_cost_per_token: Decimal | None = None
    output_cost_per_token: Decimal | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    mode: str | None = None
    supports_vision: bool | None = None
    supports_function_calling: bool | None = None
    supports_tool_choice: bool | None = None
    supports_parallel_tool_calls: bool | None = None
    supports_structured_output: bool | None = None
    raw: dict[str, Any] = Field(default_factory=dict)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def model_string(self) -> str:
        """Return the LiteLLM registry key used to identify the model."""
        return self.model_key

    @computed_field  # type: ignore[prop-decorator]
    @property
    def release_date(self) -> str | None:
        """Return the best available release-date label for display.

        Preference order: explicit raw date keys (``release_date``,
        ``released_at``, ``created_at``), the ``created_at`` field, the
        ``created`` Unix timestamp, and finally a date parsed from the model key.
        """
        for key in ("release_date", "released_at", "created_at"):
            value = self.raw.get(key)
            if value:
                return str(value)[:10]
        if self.created_at:
            return self.created_at[:10]
        if self.created is not None:
            try:
                created = datetime.fromtimestamp(self.created, tz=timezone.utc)
            except (OverflowError, OSError, ValueError):
                # An out-of-range timestamp (e.g. one expressed in milliseconds)
                # must not break serialization of the whole row; fall back to
                # parsing the model key instead.
                created = None
            if created is not None:
                return created.date().isoformat()
        return _date_label_from_text(self.model_key)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def input_cost_per_1m_tokens(self) -> Decimal | None:
        """Return input-token cost normalized to one million tokens."""
        if self.input_cost_per_token is None:
            return None
        return self.input_cost_per_token * Decimal(1_000_000)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def output_cost_per_1m_tokens(self) -> Decimal | None:
        """Return output-token cost normalized to one million tokens."""
        if self.output_cost_per_token is None:
            return None
        return self.output_cost_per_token * Decimal(1_000_000)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def context_window(self) -> int | None:
        """Return the best known input/context window."""
        return self.max_input_tokens or _coerce_int(self.raw.get("max_tokens"))

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_chat(self) -> bool:
        """Return whether this registry row looks usable for chat/generation."""
        return _is_chat_like(self)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_reasoning(self) -> bool:
        """Return whether this model appears reasoning-oriented (name heuristic)."""
        return _name_contains(self.model_key, _REASONING_WORDS) or _name_contains(
            self.model_key, _EXPENSIVE_SPECIAL_WORDS
        )

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_coding(self) -> bool:
        """Return whether this model appears coding-oriented (name heuristic)."""
        return _name_contains(self.model_key, _CODING_WORDS)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_tool_calling(self) -> bool:
        """Return whether this row appears tool/function-call capable."""
        return (
            _coerce_bool_any(
                self.raw,
                (
                    "supports_tool_calling",
                    "tool_calling",
                    "supports_function_calling",
                    "function_calling",
                    "tools",
                ),
            )
            is True
            or self.supports_function_calling is True
        )

    @computed_field  # type: ignore[prop-decorator]
    @property
    def capability_labels(self) -> list[str]:
        """Return display labels for inferred model capabilities."""
        labels: list[str] = []
        if self.supports_chat:
            labels.append("chat")
        if self.supports_reasoning:
            labels.append("reasoning")
        if self.supports_coding:
            labels.append("coding")
        if self.supports_vision is True or _name_contains(self.model_key, _VISION_WORDS):
            labels.append("vision")
        if self.supports_tool_calling:
            # Function calling and tool calling are reported as one capability.
            labels.append("function_calling")
            labels.append("tool_calling")
        if self.supports_tool_choice is True:
            labels.append("tool_choice")
        if self.supports_parallel_tool_calls is True:
            labels.append("parallel_tool_calls")
        if self.supports_structured_output is True:
            labels.append("structured_output")
        if _cheap_name_score(self.model_key) > 0:
            labels.append("cheap")
        return labels
class LiteLLMRegistryResult(BaseModel):
    """Container returned by ``list_litellm_registry``.

    Holds the registry rows that survived filtering together with
    human-readable notes describing how the listing was produced, or why it
    is empty.
    """

    model_config = ConfigDict(extra="forbid")

    # Filtered and sorted registry rows.
    models: list[LiteLLMRegistryModel] = Field(default_factory=list)
    # Explanatory messages such as row counts or a registry load failure.
    notes: list[str] = Field(default_factory=list)
def list_litellm_registry(
    *,
    providers: Iterable[str] | None = None,
    include_non_chat: bool = False,
    capabilities: Iterable[ModelCapabilityName] | None = None,
    min_context_tokens: int | None = None,
    min_output_tokens: int | None = None,
    max_input_cost_per_1m: Decimal | None = None,
    max_output_cost_per_1m: Decimal | None = None,
    released_after: str | None = None,
    released_before: str | None = None,
    sort_by: ModelCatalogSortName = "provider",
    strict: bool = False,
) -> LiteLLMRegistryResult:
    """List LiteLLM's raw local registry across arbitrary provider labels.

    Args:
        providers: Optional provider labels such as ``openrouter`` or
            ``fireworks_ai``. Unknown labels simply produce no rows.
        include_non_chat: Include embeddings, audio, image, rerank, and other
            non-chat registry entries.
        capabilities: Required capability labels.
        min_context_tokens: Optional minimum context/input-token window.
        min_output_tokens: Optional minimum output-token limit.
        max_input_cost_per_1m: Optional maximum input cost per one million tokens.
        max_output_cost_per_1m: Optional maximum output cost per one million tokens.
        released_after: Optional lower release-date bound.
        released_before: Optional upper release-date bound.
        sort_by: Sort mode for returned rows.
        strict: Raise when LiteLLM cannot be imported or inspected.

    Returns:
        Raw registry rows and explanatory notes. When the registry cannot be
        loaded and ``strict`` is false, an empty result carrying the load error
        as a note is returned instead of raising.

    Raises:
        ValueError: If ``released_after``/``released_before`` cannot be parsed,
            or if a requested capability is unsupported (detected when a row
            reaches the capability check).
        RuntimeError: If ``strict`` is true and the registry cannot be loaded.
    """
    # Normalize and validate caller-supplied filters before touching LiteLLM so
    # malformed date bounds are reported even when the registry fails to load.
    provider_filters = _normalize_provider_filters(providers)
    capability_list = list(capabilities) if capabilities is not None else None
    released_after_score = _filter_date_score(released_after)
    released_before_score = _filter_date_score(released_before)

    try:
        registry = _load_litellm_registry()
    except Exception as exc:
        if strict:
            raise RuntimeError(f"Could not load LiteLLM registry: {exc}") from exc
        return LiteLLMRegistryResult(models=[], notes=[f"Could not load LiteLLM registry: {exc}"])

    models: list[LiteLLMRegistryModel] = []
    for model_key, raw_value in registry.items():
        raw = _coerce_plain_dict(raw_value)
        if not raw:
            # Entries with no usable metadata cannot be normalized or filtered.
            continue
        model = _model_from_registry_entry(str(model_key), raw)
        if provider_filters and not _matches_provider_filter(model, provider_filters):
            continue
        if not include_non_chat and not model.supports_chat:
            continue
        if not _matches_filters(
            model,
            capabilities=capability_list,
            min_context_tokens=min_context_tokens,
            min_output_tokens=min_output_tokens,
            max_input_cost_per_1m=max_input_cost_per_1m,
            max_output_cost_per_1m=max_output_cost_per_1m,
            released_after_score=released_after_score,
            released_before_score=released_before_score,
        ):
            continue
        models.append(model)

    provider_count = len({model.provider for model in models})
    notes = [
        f"Loaded {len(models)} LiteLLM registry rows across {provider_count} provider labels.",
        "Raw LiteLLM registry rows are exploratory metadata; use supported-provider catalogs for factory defaults.",
    ]
    return LiteLLMRegistryResult(models=_sort_registry_models(models, sort_by=sort_by), notes=notes)
def _load_litellm_registry() -> Mapping[str, Any]:
    """Import LiteLLM and return its model pricing/context registry.

    Returns the first of ``litellm.model_cost`` or
    ``litellm.model_prices_and_context_window_json`` that is a mapping, or an
    empty dict when neither is. Import errors propagate to the caller.
    """
    litellm_module = importlib.import_module("litellm")
    for attr in ("model_cost", "model_prices_and_context_window_json"):
        value = getattr(litellm_module, attr, None)
        if isinstance(value, Mapping):
            return value
    return {}


def _model_from_registry_entry(model_key: str, raw: Mapping[str, Any]) -> LiteLLMRegistryModel:
    """Normalize one raw registry entry into a ``LiteLLMRegistryModel``."""
    raw_dict = dict(raw)
    provider = _registry_provider(model_key, raw_dict)
    # Strip a leading "<provider>/" prefix so ``model_id`` is the bare model name.
    prefix, separator, remainder = model_key.partition("/")
    model_id = remainder if separator and prefix.lower() == provider else model_key
    created_at = raw_dict.get("created_at")
    mode = raw_dict.get("mode")
    return LiteLLMRegistryModel(
        provider=provider,
        model_key=model_key,
        model_id=model_id,
        display_name=str(raw_dict.get("display_name") or model_id),
        created=_coerce_int(raw_dict.get("created")),
        created_at=str(created_at) if created_at is not None else None,
        input_cost_per_token=_coerce_decimal(raw_dict.get("input_cost_per_token")),
        output_cost_per_token=_coerce_decimal(raw_dict.get("output_cost_per_token")),
        max_input_tokens=_coerce_int(raw_dict.get("max_input_tokens") or raw_dict.get("max_tokens")),
        max_output_tokens=_coerce_int(
            raw_dict.get("max_output_tokens") or raw_dict.get("max_output_tokens_per_response")
        ),
        mode=str(mode).lower() if mode is not None else None,
        supports_vision=_coerce_bool_any(raw_dict, ("supports_vision", "vision")),
        supports_function_calling=_coerce_bool_any(raw_dict, ("supports_function_calling", "function_calling", "tools")),
        supports_tool_choice=_coerce_bool_any(raw_dict, ("supports_tool_choice", "tool_choice")),
        supports_parallel_tool_calls=_coerce_bool_any(
            raw_dict,
            ("supports_parallel_tool_calls", "supports_parallel_function_calling", "parallel_tool_calls"),
        ),
        supports_structured_output=_coerce_bool_any(
            raw_dict,
            (
                "supports_structured_output",
                "structured_output",
                "supports_response_schema",
                "response_schema",
                "supports_json_schema",
                "json_schema",
                "supports_json_mode",
                "json_mode",
            ),
        ),
        raw=raw_dict,
    )


def _registry_provider(model_key: str, raw: Mapping[str, Any]) -> str:
    """Return the lowercase provider label for a registry entry.

    Prefers the entry's ``litellm_provider``/``provider`` value, then the
    ``<provider>/`` prefix of the key, and finally ``"unknown"``.
    """
    provider = raw.get("litellm_provider") or raw.get("provider")
    if provider:
        return str(provider).strip().lower()
    if "/" in model_key:
        return model_key.split("/", 1)[0].strip().lower()
    return "unknown"


def _normalize_provider_filters(providers: Iterable[str] | None) -> set[str]:
    """Return stripped, lowercased, non-blank provider filter labels."""
    if providers is None:
        return set()
    return {provider.strip().lower() for provider in providers if provider.strip()}


def _matches_provider_filter(model: LiteLLMRegistryModel, filters: set[str]) -> bool:
    """Return whether ``model`` belongs to one of the requested provider labels.

    Rows whose provider could not be determined (``"unknown"``) still match
    when their key starts with ``"<label>/"``.
    """
    model_key = model.model_key.lower()
    provider = model.provider.lower()
    return any(
        provider == item or (provider == "unknown" and model_key.startswith(f"{item}/"))
        for item in filters
    )


def _matches_filters(
    model: LiteLLMRegistryModel,
    *,
    capabilities: Iterable[ModelCapabilityName] | None,
    min_context_tokens: int | None,
    min_output_tokens: int | None,
    max_input_cost_per_1m: Decimal | None,
    max_output_cost_per_1m: Decimal | None,
    released_after_score: int | None,
    released_before_score: int | None,
) -> bool:
    """Return whether ``model`` satisfies every supplied (non-``None``) filter.

    Rows missing the metadata a filter needs (unknown context, cost, or
    release date) are excluded whenever that filter is active.
    """
    if capabilities is not None and not all(_has_capability(model, capability) for capability in capabilities):
        return False
    if min_context_tokens is not None:
        context = model.context_window
        if context is None or context < min_context_tokens:
            return False
    if min_output_tokens is not None:
        output_limit = model.max_output_tokens
        if output_limit is None or output_limit < min_output_tokens:
            return False
    if max_input_cost_per_1m is not None:
        input_cost = model.input_cost_per_1m_tokens
        if input_cost is None or input_cost > max_input_cost_per_1m:
            return False
    if max_output_cost_per_1m is not None:
        output_cost = model.output_cost_per_1m_tokens
        if output_cost is None or output_cost > max_output_cost_per_1m:
            return False
    if released_after_score is not None or released_before_score is not None:
        release_score = _release_score(model)
        if release_score == 0:
            return False
        if released_after_score is not None and release_score < released_after_score:
            return False
        if released_before_score is not None and release_score > released_before_score:
            return False
    return True


def _has_capability(model: LiteLLMRegistryModel, capability: ModelCapabilityName) -> bool:
    """Return whether ``model`` exhibits ``capability``.

    Raises:
        ValueError: If ``capability`` is not a recognized capability name.
    """
    if capability == "chat":
        return model.supports_chat
    if capability == "reasoning":
        return model.supports_reasoning
    if capability == "coding":
        return model.supports_coding
    if capability == "vision":
        return model.supports_vision is True or _name_contains(model.model_key, _VISION_WORDS)
    if capability in {"function_calling", "tool_calling"}:
        return model.supports_tool_calling
    if capability == "tool_choice":
        return model.supports_tool_choice is True
    if capability == "parallel_tool_calls":
        return model.supports_parallel_tool_calls is True
    if capability == "structured_output":
        return model.supports_structured_output is True
    if capability == "cheap":
        return _cheap_name_score(model.model_key) > 0
    raise ValueError(f"Unsupported model capability filter: {capability!r}.")


def _sort_registry_models(
    models: list[LiteLLMRegistryModel],
    *,
    sort_by: ModelCatalogSortName,
) -> list[LiteLLMRegistryModel]:
    """Return ``models`` sorted according to ``sort_by``.

    Cost sorts are ascending with unknown costs last; a known cost of zero is a
    real (free) price and sorts first. Token sorts are descending. Any other
    mode sorts by recency (newest first).
    """
    if sort_by == "provider":
        return sorted(models, key=lambda model: (model.provider, -_recency_score(model), model.model_key))
    if sort_by == "model":
        return sorted(models, key=lambda model: model.model_key)
    if sort_by == "cost":
        return sorted(models, key=lambda model: _missing_last(_model_cost(model)))
    if sort_by == "input_cost":
        return sorted(models, key=lambda model: _missing_last(model.input_cost_per_1m_tokens))
    if sort_by == "output_cost":
        return sorted(models, key=lambda model: _missing_last(model.output_cost_per_1m_tokens))
    if sort_by in {"context", "input_tokens"}:
        return sorted(models, key=lambda model: -(model.context_window or 0))
    if sort_by == "output_tokens":
        return sorted(models, key=lambda model: -(model.max_output_tokens or 0))
    return sorted(models, key=lambda model: (-_recency_score(model), model.provider, model.model_key))


def _missing_last(value: Decimal | None) -> tuple[bool, Decimal]:
    """Build an ascending sort key that places ``None`` after every known value.

    An explicit ``is None`` check is required: ``Decimal(0)`` is falsy, so the
    ``value or infinity`` idiom would push free models to the end.
    """
    if value is None:
        return (True, Decimal("Infinity"))
    return (False, value)


def _is_chat_like(model: LiteLLMRegistryModel) -> bool:
    """Return whether the row looks like a chat/generation model.

    Rejects keys containing any ``_EXCLUDED_NAME_PARTS`` substring and rows
    whose explicit ``mode`` is not in ``_CHAT_MODES``.
    """
    name = model.model_key.lower()
    if any(part in name for part in _EXCLUDED_NAME_PARTS):
        return False
    if model.mode is not None and model.mode not in _CHAT_MODES:
        return False
    return True


def _model_cost(model: LiteLLMRegistryModel) -> Decimal | None:
    """Return the sum of known per-token input/output costs, or ``None``."""
    costs = [model.input_cost_per_token, model.output_cost_per_token]
    known = [cost for cost in costs if cost is not None]
    if not known:
        return None
    return sum(known, Decimal("0"))


def _cheap_name_score(name: str) -> int:
    """Count small/cheap-model hints that appear as whole tokens of ``name``.

    The name is split on every non-alphanumeric character, so ``gpt-4o-mini``
    and ``llama-3.1-8b`` match, while ``gemini`` (which contains ``mini``) and
    ``13b`` (which contains ``3b``) no longer produce false positives.
    """
    tokens = set(re.split(r"[^a-z0-9]+", name.lower()))
    return sum(1 for word in _SMALL_MODEL_WORDS if word in tokens)


def _name_contains(name: str, words: Sequence[str]) -> bool:
    """Return whether any of ``words`` occurs as a substring of ``name`` (case-insensitive)."""
    normalized = name.lower()
    return any(word in normalized for word in words)


def _coerce_plain_dict(value: Any) -> dict[str, Any]:
    """Best-effort conversion of a registry value into a plain ``dict``."""
    if value is None:
        return {}
    if isinstance(value, Mapping):
        return dict(value)
    if hasattr(value, "model_dump"):
        return dict(value.model_dump())
    if hasattr(value, "dict"):
        return dict(value.dict())
    if hasattr(value, "__dict__"):
        return {key: item for key, item in vars(value).items() if not key.startswith("_")}
    return {}


def _coerce_decimal(value: Any) -> Decimal | None:
    """Convert ``value`` to ``Decimal``; return ``None`` when missing, invalid, or NaN."""
    if value in (None, ""):
        return None
    try:
        result = Decimal(str(value))
    except Exception:
        return None
    # A NaN Decimal raises InvalidOperation on ordering comparisons, which
    # would crash cost filters and cost sorting.
    if result.is_nan():
        return None
    return result


def _coerce_int(value: Any) -> int | None:
    """Convert ``value`` to ``int``; return ``None`` when missing or unconvertible."""
    if value in (None, ""):
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers infinite floats.
        return None


def _coerce_bool(value: Any) -> bool | None:
    """Interpret common truthy/falsy spellings; ``None`` for unknown strings or ``None``."""
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"true", "1", "yes", "y", "on"}:
            return True
        if normalized in {"false", "0", "no", "n", "off"}:
            return False
        return None
    return bool(value)


def _coerce_bool_any(raw: Mapping[str, Any], keys: Sequence[str]) -> bool | None:
    """Return ``True`` if any present key is truthy, ``False`` if keys exist but none are, else ``None``."""
    found = False
    for key in keys:
        if key not in raw:
            continue
        found = True
        if _coerce_bool(raw.get(key)) is True:
            return True
    return False if found else None


def _date_score_from_text(text: str) -> int:
    """Return the latest date found in ``text`` as a ``YYYYMMDD`` integer, or ``0``.

    Recognizes ``YYYY[-_]MM[-_]DD``, ``YYYY[-_]MM`` (day ``00``), and
    ``YYMM`` segments for years 2024-2040.
    """
    normalized = text.lower()
    best = 0
    for year, month, day in re.findall(r"(20\d{2})[-_]?([01]\d)[-_]?([0-3]\d)", normalized):
        best = max(best, int(year) * 10000 + int(month) * 100 + int(day))
    for year, month in re.findall(r"\b(20\d{2})[-_]?([01]\d)\b", normalized):
        best = max(best, int(year) * 10000 + int(month) * 100)
    for suffix in re.findall(r"(?:^|[-_])(\d{4})(?:$|[-_])", normalized):
        year = int(suffix[:2])
        month = int(suffix[2:])
        if 1 <= month <= 12 and 24 <= year <= 40:
            best = max(best, (2000 + year) * 10000 + month * 100)
    return best


def _date_label_from_text(text: str) -> str | None:
    """Return a ``YYYY-MM-DD`` or ``YYYY-MM`` label parsed from ``text``, or ``None``."""
    normalized = text.lower()
    matches = re.findall(r"(20\d{2})[-_]?([01]\d)[-_]?([0-3]\d)", normalized)
    if matches:
        year, month, day = max(matches)
        return f"{year}-{month}-{day}"
    matches = re.findall(r"\b(20\d{2})[-_]?([01]\d)\b", normalized)
    if matches:
        year, month = max(matches)
        return f"{year}-{month}"
    parsed: list[tuple[int, int]] = []
    for suffix in re.findall(r"(?:^|[-_])(\d{4})(?:$|[-_])", normalized):
        year = int(suffix[:2])
        month = int(suffix[2:])
        if 1 <= month <= 12 and 24 <= year <= 40:
            parsed.append((2000 + year, month))
    if parsed:
        year, month = max(parsed)
        return f"{year}-{month:02d}"
    return None


def _filter_date_score(value: str | None) -> int | None:
    """Parse a user-supplied date bound into a ``YYYYMMDD`` score.

    Accepts anything ``_date_score_from_text`` understands plus a bare year
    such as ``"2024"`` (scored as ``20240000``).

    Raises:
        ValueError: If ``value`` contains no recognizable date.
    """
    if value is None:
        return None
    score = _date_score_from_text(value)
    if score == 0:
        year_match = re.fullmatch(r"\s*(20\d{2})\s*", value)
        if year_match:
            score = int(year_match.group(1)) * 10000
    if score == 0:
        raise ValueError(f"Could not parse date filter: {value!r}.")
    return score


def _created_score(model: LiteLLMRegistryModel) -> int:
    """Return the model's creation time in Unix seconds, or ``0`` if unknown."""
    if model.created is not None:
        return model.created
    if model.created_at:
        try:
            parsed = datetime.fromisoformat(model.created_at.replace("Z", "+00:00"))
        except ValueError:
            return 0
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return int(parsed.timestamp())
    return 0


def _created_date_score(model: LiteLLMRegistryModel) -> int:
    """Return the creation metadata as a UTC ``YYYYMMDD`` score, or ``0`` if unknown.

    ``_created_score`` yields Unix seconds (~1.7e9), which dwarf the
    ``YYYYMMDD`` scores (~2.0e7) from ``_date_score_from_text``. Converting here
    keeps both on one scale so they can be combined with ``max`` and compared
    against date-filter bounds.
    """
    timestamp = _created_score(model)
    if timestamp == 0:
        return 0
    try:
        created = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    except (OverflowError, OSError, ValueError):
        return 0
    return created.year * 10000 + created.month * 100 + created.day


def _version_score(name: str) -> int:
    """Encode up to six numeric groups of ``name`` (each capped at 999) into one integer."""
    score = 0
    scale = 10**15
    for number in re.findall(r"\d+", name)[:6]:
        score += min(int(number), 999) * scale
        scale //= 1000
    return score


def _recency_score(model: LiteLLMRegistryModel) -> int:
    """Return a sortable recency score: ``latest`` flag, then date, then version numbers."""
    name = model.model_key.lower()
    latest_score = 1 if "latest" in name else 0
    return (
        latest_score * 10**30
        + max(_created_date_score(model), _date_score_from_text(name)) * 10**18
        + _version_score(name)
    )


def _release_score(model: LiteLLMRegistryModel) -> int:
    """Return the best-known release date as a ``YYYYMMDD`` score, or ``0`` if unknown."""
    release_date = model.release_date
    return max(
        _created_date_score(model),
        _date_score_from_text(model.model_key),
        _date_score_from_text(release_date or ""),
    )