"""Reusable model suites for comparisons and multi-model workflows.
Purpose:
Build configurable shortlists of models that can be iterated, serialized,
converted to profiles, or materialized as LangChain chat models/runtimes.
"""
from __future__ import annotations
from collections.abc import Iterable, Iterator, Mapping, Sequence
from decimal import Decimal
import re
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
from .catalog import ListModelsConfig
from .model_defaults import (
ModelCapabilityName,
ModelCatalogSortName,
ModelDefaultCandidate,
ModelDefaultSource,
list_model_catalog,
)
from .profiles import CacheKeyPolicy, ChatModelProfile, LLM
from .providers import Provider, normalize_provider_name
from .reasoning import ReasoningInput
from .settings import AppSettings, ModelPresetName
from .types import ModelString
# Rendering style for ``ModelSuiteEntry.model_name``: ``as_langchain()``,
# ``as_litellm()``, or the bare ``model_name`` of the parsed model string.
ModelSuiteStyle = Literal["langchain", "litellm", "bare"]

# Built-in suite names accepted by ``get_model_suite``.
ModelSuiteName = Literal["practical", "comparison", "testing", "all-presets"]

# Role an entry plays in a suite: the preset names listed in
# ``ALL_MODEL_SUITE_PRESETS`` plus ``"custom"`` for anything else.
ModelSuiteRoleName = Literal[
    "default",
    "latest",
    "cheap",
    "testing",
    "fast",
    "balanced",
    "reasoning",
    "coding",
    "vision",
    "custom",
]

# Presets for the "practical" suite; also the default for ``model_suite_from_presets``.
DEFAULT_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = ("cheap", "balanced", "reasoning")

# Presets for the "comparison" suite.
COMPARISON_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = ("cheap", "fast", "balanced", "reasoning", "coding")

# Presets for the "testing" suite.
TESTING_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = ("testing",)

# Every preset; used by the "all-presets" suite.
ALL_MODEL_SUITE_PRESETS: tuple[ModelPresetName, ...] = (
    "default",
    "latest",
    "cheap",
    "testing",
    "fast",
    "balanced",
    "reasoning",
    "coding",
    "vision",
)
class ModelSuiteEntry(BaseModel):
    """One named model option inside a reusable suite.

    An entry pairs a parsed model string with the metadata needed to turn it
    into a ``ChatModelProfile``, a LangChain chat model, or an ``LLM`` runtime.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Stable entry identifier; slug-normalized by ``_validate_key`` so it is
    # safe as a dict key and run label. Required: the validator below and
    # ``to_profile``/``create_runtime`` all depend on it.
    key: str
    # Parsed, canonical model string (see ``_coerce_model``).
    model: ModelString
    # Role the entry plays in its suite; also added to profile tags/cost labels.
    role: ModelSuiteRoleName = "custom"
    # Human-readable label; copied into profile metadata when set.
    label: str | None = None
    description: str | None = None
    # Free-form capability labels; ``ModelSuite.filter`` matches against them.
    capabilities: list[str] = Field(default_factory=list)
    reasoning: ReasoningInput = None
    # Extra keyword arguments forwarded to ``ChatModelProfile``.
    profile_kwargs: dict[str, Any] = Field(default_factory=dict)
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)
    cache: CacheKeyPolicy | bool | None = None

    @field_validator("model", mode="before")
    @classmethod
    def _coerce_model(cls, value: str | ModelString) -> ModelString:
        """Parse model strings into typed model-string objects."""
        return ModelString.parse(value).canonical()

    @field_validator("key")
    @classmethod
    def _validate_key(cls, value: str) -> str:
        """Normalize suite keys so they are safe as dict keys and run labels.

        Raises:
            ValueError: if the key is blank or slugs to an empty string.
        """
        text = value.strip()
        if not text:
            raise ValueError("Model suite entry key cannot be empty.")
        return _slug(text)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def provider(self) -> Provider | None:
        """Return the model provider when known (``None`` otherwise)."""
        return self.model.provider

    def model_name(self, *, style: ModelSuiteStyle = "langchain") -> str:
        """Return the model string in the requested style.

        ``"litellm"`` uses ``as_litellm()``, ``"bare"`` returns the model name
        without provider, and anything else falls back to ``as_langchain()``.
        """
        if style == "litellm":
            return self.model.as_litellm()
        if style == "bare":
            return self.model.model_name
        return self.model.as_langchain()

    def to_profile(self, **overrides: Any) -> ChatModelProfile:
        """Convert this suite entry into a serializable chat-model profile.

        ``overrides`` are merged over ``profile_kwargs``. The keys ``id``,
        ``model``, ``reasoning``, ``cache``, ``tags``, ``metadata`` and
        ``cost_labels`` replace the entry-derived value when present; every
        other key is forwarded to ``ChatModelProfile`` unchanged.
        """
        profile_kwargs = dict(self.profile_kwargs)
        profile_kwargs.update(overrides)
        metadata = dict(self.metadata)
        metadata.setdefault("ooai_model_suite_entry", self.key)
        if self.label:
            metadata.setdefault("ooai_model_suite_label", self.label)
        cost_labels = dict(self.cost_labels)
        cost_labels.setdefault("model_suite_role", self.role)
        # Suite builders usually put the role in ``tags`` already; only append
        # it when missing so profiles don't carry duplicate tags.
        tags = list(self.tags)
        if self.role not in tags:
            tags.append(self.role)
        return ChatModelProfile(
            id=profile_kwargs.pop("id", self.key),
            model=profile_kwargs.pop("model", self.model.as_langchain()),
            reasoning=profile_kwargs.pop("reasoning", self.reasoning),
            cache=profile_kwargs.pop("cache", self.cache),
            tags=profile_kwargs.pop("tags", tags),
            metadata=profile_kwargs.pop("metadata", metadata),
            cost_labels=profile_kwargs.pop("cost_labels", cost_labels),
            **profile_kwargs,
        )

    def create_llm(self, *, settings: AppSettings | None = None, **overrides: Any) -> Any:
        """Create a LangChain chat model for this entry via ``to_profile``."""
        return self.to_profile(**overrides).create_llm(settings=settings)

    def create_runtime(
        self,
        *,
        settings: AppSettings | None = None,
        id: str | None = None,
        **overrides: Any,
    ) -> LLM:
        """Create an ``LLM`` runtime for this entry.

        The runtime id defaults to the entry key when ``id`` is not given.
        """
        return self.to_profile(**overrides).create_runtime(settings=settings, id=id or self.key)
class ModelSuite(BaseModel):
    """A named, iterable set of model options.

    Supports ``len()``, iteration over entries, and lookup by integer index or
    entry key. It also provides helpers to filter the suite and to
    materialize every entry as a profile, LangChain chat model, or ``LLM``
    runtime.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Suite label. The builders in this module always pass it; the default keeps
    # ad-hoc ``ModelSuite(entries=...)`` construction working.
    name: str = "custom"
    description: str | None = None
    entries: list[ModelSuiteEntry] = Field(default_factory=list)
    # Informational messages gathered while building the suite.
    notes: list[str] = Field(default_factory=list)

    def __iter__(self) -> Iterator[ModelSuiteEntry]:  # type: ignore[override]
        """Iterate over suite entries.

        Unlike ``BaseModel``'s default, this does not yield ``(field, value)`` pairs.
        """
        return iter(self.entries)

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self.entries)

    def __getitem__(self, key: int | str) -> ModelSuiteEntry:
        """Get an entry by index or key.

        A string key is matched exactly first. If that fails, it is matched
        again after the slug normalization that entry keys themselves receive.
        This mirrors how ``filter(keys=...)`` treats keys, so
        ``suite["OpenAI cheap"]`` finds ``"openai-cheap"``.

        Raises:
            IndexError: for an out-of-range integer index.
            KeyError: when no entry matches a string key.
        """
        if isinstance(key, int):
            return self.entries[key]
        for entry in self.entries:
            if entry.key == key:
                return entry
        try:
            normalized = _slug(str(key))
        except ValueError:
            # Keys that slug to nothing can never match; report as missing.
            raise KeyError(key) from None
        for entry in self.entries:
            if entry.key == normalized:
                return entry
        raise KeyError(key)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def providers(self) -> list[str]:
        """Return provider values represented in the suite, in first-seen order."""
        seen: dict[str, None] = {}
        for entry in self.entries:
            if entry.provider is not None:
                seen.setdefault(entry.provider.value, None)
        return list(seen)

    def model_list(self, *, style: ModelSuiteStyle = "langchain") -> list[str]:
        """Return suite models as an ordered list of strings."""
        return [entry.model_name(style=style) for entry in self.entries]

    def model_dict(self, *, style: ModelSuiteStyle = "langchain") -> dict[str, str]:
        """Return suite models keyed by entry key."""
        return {entry.key: entry.model_name(style=style) for entry in self.entries}

    def entry_dict(self) -> dict[str, ModelSuiteEntry]:
        """Return full suite entries keyed by entry key."""
        return {entry.key: entry for entry in self.entries}

    def by_provider(self) -> dict[str, list[ModelSuiteEntry]]:
        """Group entries by provider value; entries without a provider go under ``"unknown"``."""
        grouped: dict[str, list[ModelSuiteEntry]] = {}
        for entry in self.entries:
            provider = entry.provider.value if entry.provider is not None else "unknown"
            grouped.setdefault(provider, []).append(entry)
        return grouped

    def filter(
        self,
        *,
        providers: Iterable[Provider | str] | None = None,
        roles: Iterable[ModelSuiteRoleName | str] | None = None,
        capabilities: Iterable[str] | None = None,
        keys: Iterable[str] | None = None,
    ) -> ModelSuite:
        """Return a filtered copy of this suite.

        Criteria are ANDed; a criterion that is ``None`` or empty is ignored.
        Providers are normalized, keys are slug-normalized, and an entry must
        have *all* requested capabilities.
        """
        provider_set = _provider_value_set(providers)
        role_set = {str(role) for role in roles or []}
        capability_set = {str(capability) for capability in capabilities or []}
        key_set = {_slug(key) for key in keys or []}
        entries = [
            entry
            for entry in self.entries
            if (not provider_set or (entry.provider is not None and entry.provider.value in provider_set))
            and (not role_set or entry.role in role_set)
            and (not capability_set or capability_set.issubset(set(entry.capabilities)))
            and (not key_set or entry.key in key_set)
        ]
        return self.model_copy(update={"entries": entries})

    def to_profiles(self, **overrides: Any) -> dict[str, ChatModelProfile]:
        """Convert every entry into a chat-model profile, keyed by entry key."""
        return {entry.key: entry.to_profile(**overrides) for entry in self.entries}

    def create_llms(self, *, settings: AppSettings | None = None, **overrides: Any) -> dict[str, Any]:
        """Create LangChain chat models for every suite entry, keyed by entry key."""
        return {
            entry.key: entry.create_llm(settings=settings, **overrides)
            for entry in self.entries
        }

    def create_runtimes(self, *, settings: AppSettings | None = None, **overrides: Any) -> dict[str, LLM]:
        """Create ``LLM`` runtimes for every suite entry, keyed by entry key."""
        return {
            entry.key: entry.create_runtime(settings=settings, **overrides)
            for entry in self.entries
        }
def list_model_suite_names() -> list[str]:
    """Return built-in suite names.

    The list is derived from the ``ModelSuiteName`` literal, so it cannot drift
    from the type that ``get_model_suite`` advertises. Each call returns a
    fresh list.
    """
    from typing import get_args

    return list(get_args(ModelSuiteName))
def get_model_suite(
    name: ModelSuiteName | str = "practical",
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    presets: Iterable[ModelPresetName] | None = None,
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
) -> ModelSuite:
    """Build a named suite from configured provider presets.

    ``name`` is trimmed, lower-cased and has underscores turned into hyphens.
    When ``presets`` is omitted, the normalized name selects the preset tuple,
    and unknown names raise ``ValueError``. When ``presets`` is supplied, it
    is used as-is and the name only labels the suite.
    """
    normalized_name = str(name).strip().lower().replace("_", "-")
    chosen_presets = _suite_presets(normalized_name) if presets is None else presets
    return model_suite_from_presets(
        name=normalized_name,
        presets=chosen_presets,
        settings=settings,
        providers=providers,
        temperature=temperature,
        parallel_tool_calls=parallel_tool_calls,
    )
def model_suite_from_presets(
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    presets: Iterable[ModelPresetName] = DEFAULT_MODEL_SUITE_PRESETS,
    name: str = "provider-presets",
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
) -> ModelSuite:
    """Build a suite from ``AppSettings`` provider presets.

    For every (provider, preset) pair, the configured model is looked up in
    ``settings.llm.defaults_by_provider`` and turned into an entry keyed
    ``"<provider>-<preset>"``. A pair whose model was already added for the
    same provider is skipped, so each distinct model appears once per provider.

    Providers with no configured presets, and presets with no configured
    model, are skipped and reported in ``ModelSuite.notes``. This avoids
    crashing, which matters because the default provider list includes
    providers a given configuration may not cover.

    Args:
        settings: Settings to read presets from; a fresh ``AppSettings()`` when omitted.
        providers: Providers to include; see ``_normalize_providers`` for the default.
        presets: Preset names to resolve for each provider, in order.
        name: Suite name, also written into tags, metadata and cost labels.
        temperature: Optional ``temperature`` forwarded to each profile.
        parallel_tool_calls: Optional ``parallel_tool_calls`` forwarded to each profile.

    Raises:
        ValueError: if a provider name cannot be normalized.
    """
    resolved_settings = settings or AppSettings()
    provider_list = _normalize_providers(providers)
    preset_list = list(presets)

    # Per-entry profile overrides; copied for each entry so entries never share a dict.
    base_profile_kwargs: dict[str, Any] = {}
    if temperature is not None:
        base_profile_kwargs["temperature"] = temperature
    if parallel_tool_calls is not None:
        base_profile_kwargs["parallel_tool_calls"] = parallel_tool_calls

    entries: list[ModelSuiteEntry] = []
    notes: list[str] = []
    seen_models: set[tuple[str, str]] = set()
    for provider in provider_list:
        provider_presets = resolved_settings.llm.defaults_by_provider.get(provider)
        if provider_presets is None:
            notes.append(f"No model presets configured for provider {provider.value!r}; skipped.")
            continue
        for preset in preset_list:
            raw_model = provider_presets.get(preset)
            if raw_model is None or (isinstance(raw_model, str) and not raw_model.strip()):
                notes.append(f"Preset {preset!r} is not configured for provider {provider.value!r}; skipped.")
                continue
            model = ModelString.parse(raw_model).canonical()
            dedupe_key = (provider.value, model.as_langchain())
            if dedupe_key in seen_models:
                continue
            seen_models.add(dedupe_key)
            entries.append(
                ModelSuiteEntry(
                    key=f"{provider.value}-{preset}",
                    model=model,
                    role=preset,
                    label=f"{provider.value} {preset}",
                    description=f"{preset} preset for {provider.value}",
                    capabilities=_preset_capabilities(preset),
                    profile_kwargs=dict(base_profile_kwargs),
                    tags=["model-suite", name, provider.value, preset],
                    metadata={
                        "ooai_model_suite": name,
                        "ooai_model_suite_provider": provider.value,
                        "ooai_model_suite_preset": preset,
                    },
                    cost_labels={
                        "model_suite": name,
                        "provider": provider.value,
                        "preset": preset,
                    },
                )
            )
    return ModelSuite(
        name=name,
        description="Model suite built from configured provider presets.",
        entries=entries,
        notes=notes,
    )
def model_suite_from_catalog(
    *,
    settings: AppSettings | None = None,
    providers: Iterable[Provider | str] | None = None,
    source: ModelDefaultSource = "auto",
    config: ListModelsConfig | None = None,
    include_non_chat: bool = False,
    capabilities: Iterable[ModelCapabilityName] | None = None,
    min_context_tokens: int | None = None,
    min_output_tokens: int | None = None,
    max_input_cost_per_1m: Decimal | None = None,
    max_output_cost_per_1m: Decimal | None = None,
    released_after: str | None = None,
    released_before: str | None = None,
    sort_by: ModelCatalogSortName = "recency",
    limit: int | None = None,
    name: str = "catalog",
    temperature: float | None = None,
    parallel_tool_calls: bool | None = None,
    strict: bool = False,
) -> ModelSuite:
    """Build a suite from filtered catalog rows.

    Every filter argument is forwarded unchanged to ``list_model_catalog``.
    The rows are truncated to ``limit`` only when it is a positive integer;
    ``None`` or a value ``<= 0`` keeps every row. Each row becomes an entry via
    ``_entry_from_candidate``, and the catalog's notes are copied onto the suite.
    """
    catalog = list_model_catalog(
        settings=settings,
        providers=providers,
        source=source,
        config=config,
        include_non_chat=include_non_chat,
        capabilities=capabilities,
        min_context_tokens=min_context_tokens,
        min_output_tokens=min_output_tokens,
        max_input_cost_per_1m=max_input_cost_per_1m,
        max_output_cost_per_1m=max_output_cost_per_1m,
        released_after=released_after,
        released_before=released_before,
        sort_by=sort_by,
        strict=strict,
    )
    rows = catalog.models
    if limit is not None and limit > 0:
        rows = rows[:limit]
    suite_entries: list[ModelSuiteEntry] = []
    for row in rows:
        suite_entries.append(
            _entry_from_candidate(
                row,
                name=name,
                temperature=temperature,
                parallel_tool_calls=parallel_tool_calls,
            )
        )
    return ModelSuite(
        name=name,
        description="Model suite built from filtered model catalog rows.",
        entries=suite_entries,
        notes=catalog.notes,
    )
def _entry_from_candidate(
    candidate: ModelDefaultCandidate,
    *,
    name: str,
    temperature: float | None,
    parallel_tool_calls: bool | None,
) -> ModelSuiteEntry:
    """Turn one catalog candidate into a suite entry.

    The entry key is ``"<provider>-<model_id>"`` (later slug-normalized by the
    entry's key validator), the role comes from ``_role_from_candidate``, and
    catalog provenance is recorded in ``metadata``. That provenance covers the
    source, the release date and the per-1M-token costs. Costs are stored as
    strings, or ``None`` when the catalog has no price.
    """
    profile_kwargs: dict[str, Any] = {}
    if temperature is not None:
        profile_kwargs["temperature"] = temperature
    if parallel_tool_calls is not None:
        profile_kwargs["parallel_tool_calls"] = parallel_tool_calls
    role = _role_from_candidate(candidate)
    capabilities = list(candidate.capability_labels)
    provider_value = candidate.provider.value
    input_cost = candidate.input_cost_per_1m_tokens
    output_cost = candidate.output_cost_per_1m_tokens
    # The role is picked from the capability labels (or is "custom"), so the
    # raw tag list would normally repeat it; dedupe while keeping first-seen order.
    tags = list(dict.fromkeys(["model-suite", name, provider_value, role, *capabilities]))
    return ModelSuiteEntry(
        key=f"{provider_value}-{candidate.model_id}",
        model=candidate.model_string,
        role=role,
        label=candidate.display_name or candidate.model_id,
        description=f"{provider_value} catalog model from {candidate.source}",
        capabilities=capabilities,
        profile_kwargs=profile_kwargs,
        tags=tags,
        metadata={
            "ooai_model_suite": name,
            "ooai_model_suite_source": candidate.source,
            "ooai_model_suite_release_date": candidate.release_date,
            "ooai_model_suite_input_cost_per_1m": str(input_cost) if input_cost is not None else None,
            "ooai_model_suite_output_cost_per_1m": str(output_cost) if output_cost is not None else None,
        },
        cost_labels={
            "model_suite": name,
            "provider": provider_value,
            "role": role,
            "source": candidate.source,
        },
    )
def _suite_presets(name: str) -> tuple[ModelPresetName, ...]:
    """Return the preset tuple for a normalized built-in suite name.

    Raises:
        ValueError: if ``name`` is not a built-in suite; the message lists the
            supported names.
    """
    presets_by_suite: dict[str, tuple[ModelPresetName, ...]] = {
        "practical": DEFAULT_MODEL_SUITE_PRESETS,
        "comparison": COMPARISON_MODEL_SUITE_PRESETS,
        "testing": TESTING_MODEL_SUITE_PRESETS,
        "all-presets": ALL_MODEL_SUITE_PRESETS,
    }
    selected = presets_by_suite.get(name)
    if selected is None:
        supported = ", ".join(list_model_suite_names())
        raise ValueError(
            f"Unknown model suite {name!r}. Supported suites: {supported}."
        )
    return selected
def _role_from_candidate(candidate: ModelDefaultCandidate) -> ModelSuiteRoleName:
    """Pick a suite role from a catalog candidate's capability labels.

    Labels are checked in priority order ``cheap`` > ``coding`` > ``reasoning``
    > ``vision``. The first one present wins; otherwise the role is ``"custom"``.
    """
    for role_label in ("cheap", "coding", "reasoning", "vision"):
        if role_label in candidate.capability_labels:
            return role_label
    return "custom"
def _preset_capabilities(preset: ModelPresetName) -> list[str]:
    """Return capability labels for an entry built from ``preset``.

    Every preset is labelled ``"chat"`` plus its own name; ``"testing"`` is
    additionally labelled ``"cheap"``. A fresh list is returned on each call.
    """
    extra_labels = {"testing": ("testing", "cheap")}
    return ["chat", *extra_labels.get(preset, (preset,))]
def _normalize_providers(providers: Iterable[Provider | str] | None) -> list[Provider]:
    """Normalize provider inputs into a de-duplicated, ordered ``Provider`` list.

    ``None`` expands to the default provider set (OpenAI, Anthropic, Google
    GenAI, DeepSeek, Mistral). A single provider name or ``Provider`` is
    accepted as well as an iterable of them.

    Raises:
        ValueError: if any provider cannot be normalized.
    """
    if providers is None:
        return [Provider.OPENAI, Provider.ANTHROPIC, Provider.GOOGLE_GENAI, Provider.DEEPSEEK, Provider.MISTRAL]
    # A bare string is itself iterable. Without this guard, "openai" would be
    # split into characters and fail with "Unsupported provider: 'o'".
    if isinstance(providers, (str, Provider)):
        providers = [providers]
    normalized: list[Provider] = []
    for provider in providers:
        value = normalize_provider_name(provider)
        if value is None:
            raise ValueError(f"Unsupported provider: {provider!r}.")
        if value not in normalized:
            normalized.append(value)
    return normalized
def _provider_value_set(providers: Iterable[Provider | str] | None) -> set[str]:
    """Return normalized provider values used as a filter.

    ``None`` means "no provider filter" and yields an empty set. This differs
    deliberately from ``_normalize_providers(None)``, which expands to the
    default provider list.
    """
    if providers is not None:
        return set(item.value for item in _normalize_providers(providers))
    return set()
def _slug(value: str) -> str:
    """Lower-case ``value`` and reduce it to ``[a-z0-9_.-]`` with single hyphens.

    Runs of other characters become one hyphen, repeated hyphens collapse, and
    leading/trailing hyphens are trimmed.

    Raises:
        ValueError: if nothing is left after normalization.
    """
    lowered = value.strip().lower()
    hyphenated = re.sub(r"[^A-Za-z0-9_.-]+", "-", lowered)
    slug = re.sub(r"-+", "-", hyphenated).strip("-")
    if slug:
        return slug
    raise ValueError("Slug cannot be empty.")