Source code for ooai_llm.model_suites

"""Reusable model suites for comparisons and multi-model workflows.

Purpose:
    Build configurable shortlists of models that can be iterated, serialized,
    converted to profiles, or materialized as LangChain chat models/runtimes.
"""

from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping, Sequence
from decimal import Decimal
import re
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator

from .catalog import ListModelsConfig
from .model_defaults import (
    ModelCapabilityName,
    ModelCatalogSortName,
    ModelDefaultCandidate,
    ModelDefaultSource,
    list_model_catalog,
)
from .profiles import CacheKeyPolicy, ChatModelProfile, LLM
from .providers import Provider, normalize_provider_name
from .reasoning import ReasoningInput
from .settings import AppSettings, ModelPresetName
from .types import ModelString

# Rendering styles accepted by ``ModelSuiteEntry.model_name``.
ModelSuiteStyle = Literal["langchain", "litellm", "bare"]

# Names of the built-in suites (the values ``list_model_suite_names`` returns).
ModelSuiteName = Literal["practical", "comparison", "testing", "all-presets"]

# Role labels an entry may carry; ``"custom"`` is the ``ModelSuiteEntry`` default.
ModelSuiteRoleName = Literal[
    "default",
    "latest",
    "cheap",
    "testing",
    "fast",
    "balanced",
    "reasoning",
    "coding",
    "vision",
    "custom",
]

# Presets selected for the "practical" suite; also the default for
# ``model_suite_from_presets``.
DEFAULT_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = (
    "cheap",
    "balanced",
    "reasoning",
)

# Presets selected for the "comparison" suite.
COMPARISON_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = (
    "cheap",
    "fast",
    "balanced",
    "reasoning",
    "coding",
)

# Presets selected for the "testing" suite.
TESTING_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = ("testing",)

# Presets selected for the "all-presets" suite.
ALL_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = (
    "default",
    "latest",
    "cheap",
    "testing",
    "fast",
    "balanced",
    "reasoning",
    "coding",
    "vision",
)
class ModelSuiteEntry(BaseModel):
    """One named model option inside a reusable suite.

    An entry pairs a normalized ``key`` with a parsed model string plus the
    optional labels, tags, metadata and profile keyword arguments that are
    forwarded when the entry is converted into a ``ChatModelProfile``.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Entry identifier; stripped and slug-normalized by ``_validate_key``.
    key: str
    # Model reference; string input is parsed and canonicalized by ``_coerce_model``.
    model: ModelString
    # Role label for the entry; also used as a profile tag and cost label.
    role: ModelSuiteRoleName = "custom"
    label: str | None = None
    description: str | None = None
    capabilities: list[str] = Field(default_factory=list)
    reasoning: ReasoningInput = None
    # Extra keyword arguments forwarded to ``ChatModelProfile`` by ``to_profile``.
    profile_kwargs: dict[str, Any] = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)
    cache: CacheKeyPolicy | bool | None = None

    @field_validator("model", mode="before")
    @classmethod
    def _coerce_model(cls, value: str | ModelString) -> ModelString:
        """Parse model strings into typed model-string objects."""
        return ModelString.parse(value).canonical()

    @field_validator("key")
    @classmethod
    def _validate_key(cls, value: str) -> str:
        """Normalize suite keys so they are safe as dict keys and run labels.

        Raises:
            ValueError: If the key is empty after stripping whitespace.
        """
        text = value.strip()
        if not text:
            raise ValueError("Model suite entry key cannot be empty.")
        return _slug(text)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def provider(self) -> Provider | None:
        """Return the model provider when known."""
        return self.model.provider

    def model_name(self, *, style: ModelSuiteStyle = "langchain") -> str:
        """Return the model string in the requested style.

        Args:
            style: ``"litellm"`` or ``"bare"`` select those renderings; any
                other value yields the LangChain form.
        """
        if style == "litellm":
            return self.model.as_litellm()
        if style == "bare":
            return self.model.model_name
        return self.model.as_langchain()

    def to_profile(self, **overrides: Any) -> ChatModelProfile:
        """Convert this suite entry into a serializable chat-model profile.

        Args:
            **overrides: Keyword arguments layered on top of
                ``profile_kwargs``. The keys ``id``, ``model``, ``reasoning``,
                ``cache``, ``tags``, ``metadata`` and ``cost_labels`` replace
                the entry-derived values; everything else is passed through
                to ``ChatModelProfile`` unchanged.

        Returns:
            A ``ChatModelProfile`` built from the entry and the overrides.
        """
        profile_kwargs = dict(self.profile_kwargs)
        profile_kwargs.update(overrides)

        # Work on copies so building a profile never mutates the entry itself.
        metadata = dict(self.metadata)
        metadata.setdefault("ooai_model_suite_entry", self.key)
        if self.label:
            metadata.setdefault("ooai_model_suite_label", self.label)

        cost_labels = dict(self.cost_labels)
        cost_labels.setdefault("model_suite_role", self.role)

        # The suite factories already include the role/preset in ``tags``;
        # only append it when missing so the profile has no duplicate tag.
        tags = list(self.tags)
        if self.role not in tags:
            tags.append(self.role)

        return ChatModelProfile(
            id=profile_kwargs.pop("id", self.key),
            model=profile_kwargs.pop("model", self.model.as_langchain()),
            reasoning=profile_kwargs.pop("reasoning", self.reasoning),
            cache=profile_kwargs.pop("cache", self.cache),
            tags=profile_kwargs.pop("tags", tags),
            metadata=profile_kwargs.pop("metadata", metadata),
            cost_labels=profile_kwargs.pop("cost_labels", cost_labels),
            **profile_kwargs,
        )

    def create_llm(self, *, settings: AppSettings | None = None, **overrides: Any) -> Any:
        """Create a LangChain chat model for this entry.

        Args:
            settings: Optional settings forwarded to the profile's ``create_llm``.
            **overrides: Forwarded to ``to_profile``.
        """
        return self.to_profile(**overrides).create_llm(settings=settings)

    def create_runtime(
        self,
        *,
        settings: AppSettings | None = None,
        id: str | None = None,
        **overrides: Any,
    ) -> LLM:
        """Create an ``LLM`` runtime for this entry.

        Args:
            settings: Optional settings forwarded to the profile's ``create_runtime``.
            id: Runtime identifier; the entry key is used when this is falsy.
            **overrides: Forwarded to ``to_profile``.
        """
        return self.to_profile(**overrides).create_runtime(settings=settings, id=id or self.key)
class ModelSuite(BaseModel):
    """A named, iterable set of model options.

    The suite behaves like an ordered collection of ``ModelSuiteEntry``
    objects: it supports iteration, ``len()``, and lookup by position or by
    entry key.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    name: str
    description: str | None = None
    entries: list[ModelSuiteEntry] = Field(default_factory=list)
    notes: list[str] = Field(default_factory=list)

    def __iter__(self) -> Iterator[ModelSuiteEntry]:
        """Iterate over suite entries."""
        return iter(self.entries)

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self.entries)

    def __getitem__(self, key: int | str) -> ModelSuiteEntry:
        """Get an entry by index or key.

        String keys are matched exactly first. Because entry keys are
        slug-normalized at validation time, a string that does not match
        exactly is normalized the same way and tried again, which keeps this
        lookup consistent with ``filter(keys=...)``.

        Raises:
            IndexError: If an integer index is out of range.
            KeyError: If no entry matches the key.
        """
        if isinstance(key, int):
            return self.entries[key]
        for entry in self.entries:
            if entry.key == key:
                return entry
        if isinstance(key, str):
            try:
                normalized = _slug(key)
            except ValueError:
                # A key that slugs to nothing cannot match any entry.
                raise KeyError(key) from None
            if normalized != key:
                for entry in self.entries:
                    if entry.key == normalized:
                        return entry
        raise KeyError(key)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def providers(self) -> list[str]:
        """Return providers represented in the suite, in first-seen order."""
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(
            dict.fromkeys(
                entry.provider.value for entry in self.entries if entry.provider is not None
            )
        )

    def model_list(self, *, style: ModelSuiteStyle = "langchain") -> list[str]:
        """Return suite models as an ordered list of strings."""
        return [entry.model_name(style=style) for entry in self.entries]

    def model_dict(self, *, style: ModelSuiteStyle = "langchain") -> dict[str, str]:
        """Return suite models keyed by entry key."""
        return {entry.key: entry.model_name(style=style) for entry in self.entries}

    def entry_dict(self) -> dict[str, ModelSuiteEntry]:
        """Return full suite entries keyed by entry key."""
        return {entry.key: entry for entry in self.entries}

    def by_provider(self) -> dict[str, list[ModelSuiteEntry]]:
        """Group entries by provider.

        Entries whose provider is ``None`` are grouped under ``"unknown"``.
        """
        grouped: dict[str, list[ModelSuiteEntry]] = {}
        for entry in self.entries:
            provider = entry.provider.value if entry.provider is not None else "unknown"
            grouped.setdefault(provider, []).append(entry)
        return grouped

    def filter(
        self,
        *,
        providers: Iterable[Provider | str] | None = None,
        roles: Iterable[ModelSuiteRoleName | str] | None = None,
        capabilities: Iterable[str] | None = None,
        keys: Iterable[str] | None = None,
    ) -> ModelSuite:
        """Return a filtered copy of this suite.

        Every criterion that is supplied (and non-empty) must match for an
        entry to be kept; omitted or empty criteria do not restrict anything.

        Args:
            providers: Keep entries whose provider is one of these.
            roles: Keep entries whose role is one of these.
            capabilities: Keep entries that have all of these capabilities.
            keys: Keep entries whose key matches one of these after slug
                normalization.

        Raises:
            ValueError: If a provider is unsupported or a key normalizes to
                an empty slug.
        """
        provider_set = _provider_value_set(providers)
        role_set = {str(role) for role in roles or []}
        capability_set = {str(capability) for capability in capabilities or []}
        key_set = {_slug(key) for key in keys or []}
        entries = [
            entry
            for entry in self.entries
            if (
                not provider_set
                or (entry.provider is not None and entry.provider.value in provider_set)
            )
            and (not role_set or entry.role in role_set)
            and (not capability_set or capability_set.issubset(set(entry.capabilities)))
            and (not key_set or entry.key in key_set)
        ]
        return self.model_copy(update={"entries": entries})

    def to_profiles(self, **overrides: Any) -> dict[str, ChatModelProfile]:
        """Convert every entry into a chat-model profile, keyed by entry key."""
        return {entry.key: entry.to_profile(**overrides) for entry in self.entries}

    def create_llms(
        self, *, settings: AppSettings | None = None, **overrides: Any
    ) -> dict[str, Any]:
        """Create LangChain chat models for every suite entry, keyed by entry key."""
        return {
            entry.key: entry.create_llm(settings=settings, **overrides)
            for entry in self.entries
        }

    def create_runtimes(
        self, *, settings: AppSettings | None = None, **overrides: Any
    ) -> dict[str, LLM]:
        """Create ``LLM`` runtimes for every suite entry, keyed by entry key."""
        return {
            entry.key: entry.create_runtime(settings=settings, **overrides)
            for entry in self.entries
        }
def list_model_suite_names() -> list[str]:
    """Return the names of the built-in suites as a new list."""
    builtin_names = ("practical", "comparison", "testing", "all-presets")
    return list(builtin_names)
def get_model_suite(
    name: ModelSuiteName | str = "practical",
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    presets: Iterable[ModelPresetName] | None = None,
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
) -> ModelSuite:
    """Build a named suite from configured provider presets.

    Args:
        name: Suite name; it is stripped, lower-cased, and underscores are
            replaced with hyphens before use.
        settings: Forwarded to ``model_suite_from_presets``.
        providers: Forwarded to ``model_suite_from_presets``.
        presets: Explicit presets. When ``None``, the presets registered for
            the normalized suite name are used.
        temperature: Forwarded to ``model_suite_from_presets``.
        parallel_tool_calls: Forwarded to ``model_suite_from_presets``.

    Raises:
        ValueError: If ``presets`` is ``None`` and the suite name is unknown.
    """
    lowered = str(name).strip().lower()
    suite_name = lowered.replace("_", "-")

    # Only consult the built-in table when the caller gave no explicit presets.
    chosen_presets = _suite_presets(suite_name) if presets is None else presets

    return model_suite_from_presets(
        settings=settings,
        providers=providers,
        presets=chosen_presets,
        name=suite_name,
        temperature=temperature,
        parallel_tool_calls=parallel_tool_calls,
    )
def model_suite_from_presets(
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    presets: Iterable[ModelPresetName] = DEFAULT_MODEL_SUITE_PRESETS,
    name: str = "provider-presets",
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
) -> ModelSuite:
    """Build a suite from ``AppSettings`` provider presets.

    One entry is created per (provider, preset) pair. A pair is skipped when
    the same provider already produced an entry for the same model.

    Args:
        settings: Settings to read presets from; ``AppSettings()`` is created
            when this is falsy.
        providers: Providers to include; ``None`` selects the default list
            from ``_normalize_providers``.
        presets: Preset names looked up for each provider.
        name: Suite name, also recorded in entry tags, metadata and cost labels.
        temperature: Added to each entry's ``profile_kwargs`` when not ``None``.
        parallel_tool_calls: Added to each entry's ``profile_kwargs`` when
            not ``None``.
    """
    active_settings = settings or AppSettings()
    selected_providers = _normalize_providers(providers)
    selected_presets = list(presets)

    # Options shared by every entry; each entry receives its own copy below.
    shared_kwargs: dict[str, Any] = {
        option: value
        for option, value in (
            ("temperature", temperature),
            ("parallel_tool_calls", parallel_tool_calls),
        )
        if value is not None
    }

    built: list[ModelSuiteEntry] = []
    already_added: set[tuple[str, str]] = set()
    for provider in selected_providers:
        provider_presets = active_settings.llm.defaults_by_provider.get(provider)
        for preset in selected_presets:
            model = ModelString.parse(provider_presets.get(preset)).canonical()
            provider_id = provider.value
            identity = (provider_id, model.as_langchain())
            if identity in already_added:
                # Another preset of this provider resolved to the same model.
                continue
            already_added.add(identity)
            entry = ModelSuiteEntry(
                key=f"{provider_id}-{preset}",
                model=model,
                role=preset,
                label=f"{provider_id} {preset}",
                description=f"{preset} preset for {provider_id}",
                capabilities=_preset_capabilities(preset),
                profile_kwargs=dict(shared_kwargs),
                tags=["model-suite", name, provider_id, preset],
                metadata={
                    "ooai_model_suite": name,
                    "ooai_model_suite_provider": provider_id,
                    "ooai_model_suite_preset": preset,
                },
                cost_labels={
                    "model_suite": name,
                    "provider": provider_id,
                    "preset": preset,
                },
            )
            built.append(entry)

    return ModelSuite(
        name=name,
        description="Model suite built from configured provider presets.",
        entries=built,
    )
def model_suite_from_catalog(
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    source: ModelDefaultSource = "auto",
    config: ListModelsConfig | None = None,
    include_non_chat: bool = False,
    capabilities: Iterable[ModelCapabilityName] | None = None,
    min_context_tokens: int | None = None,
    min_output_tokens: int | None = None,
    max_input_cost_per_1m: Decimal | None = None,
    max_output_cost_per_1m: Decimal | None = None,
    released_after: str | None = None,
    released_before: str | None = None,
    sort_by: ModelCatalogSortName = "recency",
    limit: int | None = None,
    name: str = "catalog",
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
    strict: bool = False,
) -> ModelSuite:
    """Build a suite from filtered catalog rows.

    All filtering and sorting arguments are passed straight through to
    ``list_model_catalog``; this function only truncates the result and turns
    each remaining candidate into a ``ModelSuiteEntry``.

    Args:
        limit: Maximum number of catalog rows to keep. ``None`` or a value
            less than or equal to zero keeps every row.
        name: Suite name, also recorded on every entry.
        temperature: Added to each entry's ``profile_kwargs`` when not ``None``.
        parallel_tool_calls: Added to each entry's ``profile_kwargs`` when
            not ``None``.

    Returns:
        A ``ModelSuite`` whose notes are the notes of the catalog result.
    """
    catalog_query: dict[str, Any] = {
        "settings": settings,
        "providers": providers,
        "source": source,
        "config": config,
        "include_non_chat": include_non_chat,
        "capabilities": capabilities,
        "min_context_tokens": min_context_tokens,
        "min_output_tokens": min_output_tokens,
        "max_input_cost_per_1m": max_input_cost_per_1m,
        "max_output_cost_per_1m": max_output_cost_per_1m,
        "released_after": released_after,
        "released_before": released_before,
        "sort_by": sort_by,
        "strict": strict,
    }
    catalog = list_model_catalog(**catalog_query)

    rows = catalog.models
    if limit is not None and limit > 0:
        # Only a positive limit truncates the result.
        rows = rows[:limit]

    suite_entries: list[ModelSuiteEntry] = []
    for row in rows:
        suite_entries.append(
            _entry_from_candidate(
                row,
                name=name,
                temperature=temperature,
                parallel_tool_calls=parallel_tool_calls,
            )
        )

    return ModelSuite(
        name=name,
        description="Model suite built from filtered model catalog rows.",
        entries=suite_entries,
        notes=catalog.notes,
    )
def _entry_from_candidate( candidate: ModelDefaultCandidate, *, name: str, temperature: float | None, parallel_tool_calls: bool | None, ) -> ModelSuiteEntry: profile_kwargs: dict[str, Any] = {} if temperature is not None: profile_kwargs["temperature"] = temperature if parallel_tool_calls is not None: profile_kwargs["parallel_tool_calls"] = parallel_tool_calls role = _role_from_candidate(candidate) capabilities = candidate.capability_labels return ModelSuiteEntry( key=f"{candidate.provider.value}-{candidate.model_id}", model=candidate.model_string, role=role, label=candidate.display_name or candidate.model_id, description=f"{candidate.provider.value} catalog model from {candidate.source}", capabilities=capabilities, profile_kwargs=profile_kwargs, tags=["model-suite", name, candidate.provider.value, role, *capabilities], metadata={ "ooai_model_suite": name, "ooai_model_suite_source": candidate.source, "ooai_model_suite_release_date": candidate.release_date, "ooai_model_suite_input_cost_per_1m": ( str(candidate.input_cost_per_1m_tokens) if candidate.input_cost_per_1m_tokens is not None else None ), "ooai_model_suite_output_cost_per_1m": ( str(candidate.output_cost_per_1m_tokens) if candidate.output_cost_per_1m_tokens is not None else None ), }, cost_labels={ "model_suite": name, "provider": candidate.provider.value, "role": role, "source": candidate.source, }, ) def _suite_presets(name: str) -> tuple[ModelPresetName, ...]: if name == "practical": return DEFAULT_MODEL_SUITE_PRESETS if name == "comparison": return COMPARISON_MODEL_SUITE_PRESETS if name == "testing": return TESTING_MODEL_SUITE_PRESETS if name == "all-presets": return ALL_MODEL_SUITE_PRESETS raise ValueError( f"Unknown model suite {name!r}. Supported suites: {', '.join(list_model_suite_names())}." 
) def _role_from_candidate(candidate: ModelDefaultCandidate) -> ModelSuiteRoleName: if "cheap" in candidate.capability_labels: return "cheap" if "coding" in candidate.capability_labels: return "coding" if "reasoning" in candidate.capability_labels: return "reasoning" if "vision" in candidate.capability_labels: return "vision" return "custom" def _preset_capabilities(preset: ModelPresetName) -> list[str]: if preset == "cheap": return ["chat", "cheap"] if preset == "testing": return ["chat", "testing", "cheap"] if preset == "fast": return ["chat", "fast"] if preset == "balanced": return ["chat", "balanced"] if preset == "reasoning": return ["chat", "reasoning"] if preset == "coding": return ["chat", "coding"] if preset == "vision": return ["chat", "vision"] return ["chat", preset] def _normalize_providers(providers: Iterable[Provider | str] | None) -> list[Provider]: if providers is None: return [Provider.OPENAI, Provider.ANTHROPIC, Provider.GOOGLE_GENAI, Provider.DEEPSEEK, Provider.MISTRAL] normalized: list[Provider] = [] for provider in providers: value = normalize_provider_name(provider) if value is None: raise ValueError(f"Unsupported provider: {provider!r}.") if value not in normalized: normalized.append(value) return normalized def _provider_value_set(providers: Iterable[Provider | str] | None) -> set[str]: if providers is None: return set() return {provider.value for provider in _normalize_providers(providers)} def _slug(value: str) -> str: text = re.sub(r"[^A-Za-z0-9_.-]+", "-", value.strip().lower()) text = re.sub(r"-+", "-", text).strip("-") if not text: raise ValueError("Slug cannot be empty.") return text