"""Cost and capability comparisons over the model catalog.
Purpose:
Turn provider/LiteLLM model catalog rows into practical comparison tables:
cheapest models, coding-capable shortlists, per-call cost estimates, and
"how many calls of X equal one call of Y" ratios.
Design:
- Reuse :func:`ooai_llm.model_defaults.list_model_catalog` as the source of
truth so pricing and model availability can be refreshed independently.
- Keep estimates explicit about the assumed input/output token shape.
- Treat catalog pricing as planning data; observed provider usage metadata
remains the billing source of truth after a real call.
"""
from __future__ import annotations
from collections.abc import Iterable, Iterator, Sequence
from decimal import Decimal
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator
from .catalog import ListModelsConfig
from .model_defaults import (
ModelCapabilityName,
ModelCatalogSortName,
ModelDefaultCandidate,
ModelDefaultSource,
list_model_catalog,
)
from .providers import Provider
from .settings import AppSettings
from .types import ModelString
#: Model-string rendering styles accepted by ``ModelCostEstimate.model_name``,
#: ``ModelCostComparison.model_list``, and ``ModelCostComparison.model_dict``.
CatalogModelStyle = Literal["langchain", "litellm", "bare"]
#: Sort modes accepted by the comparison helpers. ``"cost"`` is an alias of
#: ``"call_cost"``, ``"calls"`` ranks by calls within the comparison budget,
#: and ``"context"`` is an alias of ``"input_tokens"`` (both rank by context
#: window, largest first).
ModelComparisonSortName = Literal[
    "call_cost",
    "cost",
    "calls",
    "calls_per_usd",
    "provider",
    "model",
    "input_tokens",
    "output_tokens",
    "context",
]
class ModelCallShape(BaseModel):
    """Token shape used to estimate one representative model call.

    The defaults describe a 10k-input / 2k-output call, the shape that
    :func:`estimate_model_call_cost` documents as its default.
    """

    model_config = ConfigDict(extra="forbid")

    # Input (prompt) tokens per call; must be non-negative.
    input_tokens: int = Field(default=10_000, ge=0)
    # Output (completion) tokens per call; must be non-negative.
    output_tokens: int = Field(default=2_000, ge=0)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_tokens(self) -> int:
        """Return total input plus output tokens."""
        return self.input_tokens + self.output_tokens
class ModelCostEstimate(BaseModel):
    """One model's estimated cost for a representative call shape."""

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    provider: Provider
    # Coerced to a canonical ``ModelString`` by ``_coerce_model`` below.
    model: ModelString
    # NOTE(review): restored from the keywords ``_estimate_from_candidate``
    # passes (``candidate.source``); confirm the exact upstream type.
    source: str
    release_date: str | None = None
    input_cost_per_1m_tokens: Decimal | None = None
    output_cost_per_1m_tokens: Decimal | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    context_window: int | None = None
    capabilities: list[str] = Field(default_factory=list)
    # ``None`` when catalog pricing is missing for a side the shape uses.
    call_cost_usd: Decimal | None = None
    # Throughput fields; ``_estimate_from_candidate`` leaves them ``None``
    # unless ``call_cost_usd`` is known and strictly positive.
    calls_per_usd: Decimal | None = None
    calls_per_budget: Decimal | None = None
    # Budget that ``calls_per_budget`` is measured against.
    budget_usd: Decimal = Decimal("1")

    @field_validator("model", mode="before")
    @classmethod
    def _coerce_model(cls, value: str | ModelString) -> ModelString:
        """Parse model strings into canonical typed model strings."""
        return ModelString.parse(value).canonical()

    def model_name(self, *, style: CatalogModelStyle = "langchain") -> str:
        """Return the model string in the requested style.

        Args:
            style: ``"langchain"`` (default), ``"litellm"``, or ``"bare"``
                (only ``ModelString.model_name``).
        """
        if style == "litellm":
            return self.model.as_litellm()
        if style == "bare":
            return self.model.model_name
        return self.model.as_langchain()
class ModelCallEquivalent(BaseModel):
    """Cost ratio between a baseline model call and another model call.

    ``compared_calls_per_baseline_call`` is
    ``baseline_call_cost_usd / compared_call_cost_usd``: how many calls of the
    compared model cost the same as one baseline call.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    baseline_model: ModelString
    compared_model: ModelString
    baseline_call_cost_usd: Decimal
    compared_call_cost_usd: Decimal
    compared_calls_per_baseline_call: Decimal

    @field_validator("baseline_model", "compared_model", mode="before")
    @classmethod
    def _coerce_model(cls, value: str | ModelString) -> ModelString:
        """Parse model strings into canonical typed model strings."""
        return ModelString.parse(value).canonical()
class ModelCostComparison(BaseModel):
    """Iterable cost comparison result for catalog models.

    Iteration and ``len()`` operate on :attr:`estimates`; ``notes`` carries
    caveats about how the estimates were produced.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Token shape the estimates were priced against.
    shape: ModelCallShape = Field(default_factory=ModelCallShape)
    # Budget (USD) the estimates' ``calls_per_budget`` values refer to.
    budget_usd: Decimal = Decimal("1")
    estimates: list[ModelCostEstimate] = Field(default_factory=list)
    notes: list[str] = Field(default_factory=list)

    def __iter__(self) -> Iterator[ModelCostEstimate]:  # type: ignore[override]
        """Iterate over cost estimates (not pydantic ``(field, value)`` pairs)."""
        return iter(self.estimates)

    def __len__(self) -> int:
        """Return the number of estimates."""
        return len(self.estimates)

    def model_list(self, *, style: CatalogModelStyle = "langchain") -> list[str]:
        """Return compared models as an ordered list of strings."""
        return [estimate.model_name(style=style) for estimate in self.estimates]

    def model_dict(self, *, style: CatalogModelStyle = "langchain") -> dict[str, str]:
        """Return compared models keyed by provider/model string.

        Keys use the LangChain form; a later estimate with the same key
        overwrites an earlier one.
        """
        return {
            estimate.model.as_langchain(): estimate.model_name(style=style)
            for estimate in self.estimates
        }

    def by_provider(self) -> dict[str, list[ModelCostEstimate]]:
        """Group estimates by provider value, preserving estimate order."""
        grouped: dict[str, list[ModelCostEstimate]] = {}
        for estimate in self.estimates:
            grouped.setdefault(estimate.provider.value, []).append(estimate)
        return grouped

    def cheapest_per_provider(self) -> list[ModelCostEstimate]:
        """Return the cheapest known estimate for each provider.

        "Cheapest" uses the call-cost sort key regardless of how
        :attr:`estimates` is ordered; the result is sorted by that key too.
        """
        cheapest: dict[str, ModelCostEstimate] = {}
        for estimate in self.estimates:
            provider = estimate.provider.value
            existing = cheapest.get(provider)
            if existing is None or _estimate_sort_key(estimate) < _estimate_sort_key(existing):
                cheapest[provider] = estimate
        return sorted(cheapest.values(), key=_estimate_sort_key)

    def find_model(self, model: str | ModelString) -> ModelCostEstimate | None:
        """Find an estimate by LangChain, LiteLLM, bare, or raw model string.

        Matches on a provider-qualified (non-bare) key take precedence, so a
        qualified request is not resolved to another provider's row that
        merely shares the bare model name. When only bare-name matches exist,
        the first one in :attr:`estimates` order is returned.

        Returns:
            The matching estimate, or ``None`` when nothing matches.
        """
        requested = _model_match_keys(model)
        bare_name = ModelString.parse(model).canonical().model_name
        requested_bare = bare_name.lower() if bare_name else None
        fallback: ModelCostEstimate | None = None
        for estimate in self.estimates:
            shared = requested & _model_match_keys(estimate.model)
            if not shared:
                continue
            # Any shared key other than the bare name is a qualified match.
            if shared - {requested_bare}:
                return estimate
            # Bare-name-only match: remember the first, keep looking.
            if fallback is None:
                fallback = estimate
        return fallback

    def equivalents(self, baseline_model: str | ModelString) -> list[ModelCallEquivalent]:
        """Return ratios against a baseline model in this comparison.

        The ratio answers: "how many calls of the compared model cost the same
        as one call of the baseline model" for the configured call shape.
        The baseline itself is included (ratio 1). Estimates with unknown or
        non-positive cost are skipped, and ``[]`` is returned when the
        baseline is missing or not positively priced.
        """
        baseline = self.find_model(baseline_model)
        if baseline is None or baseline.call_cost_usd is None or baseline.call_cost_usd <= 0:
            return []
        baseline_cost = baseline.call_cost_usd
        equivalents: list[ModelCallEquivalent] = []
        for estimate in self.estimates:
            # Skipping non-positive costs also avoids division by zero.
            if estimate.call_cost_usd is None or estimate.call_cost_usd <= 0:
                continue
            equivalents.append(
                ModelCallEquivalent(
                    baseline_model=baseline.model,
                    compared_model=estimate.model,
                    baseline_call_cost_usd=baseline_cost,
                    compared_call_cost_usd=estimate.call_cost_usd,
                    compared_calls_per_baseline_call=baseline_cost / estimate.call_cost_usd,
                )
            )
        return equivalents
def estimate_model_call_cost(
    candidate: ModelDefaultCandidate,
    shape: ModelCallShape | None = None,
) -> Decimal | None:
    """Estimate one call cost from catalog pricing and token shape.

    A side of the call with zero tokens contributes nothing, so its price may
    be unknown without making the whole estimate unknown.

    Args:
        candidate: Catalog model row with per-token input/output prices.
        shape: Representative token shape. Defaults to 10k input and 2k output.

    Returns:
        Estimated USD cost, or ``None`` when required pricing is unknown.
    """
    resolved_shape = shape or ModelCallShape()
    total = Decimal("0")
    if resolved_shape.input_tokens:
        if candidate.input_cost_per_token is None:
            return None
        total += Decimal(resolved_shape.input_tokens) * candidate.input_cost_per_token
    if resolved_shape.output_tokens:
        if candidate.output_cost_per_token is None:
            return None
        total += Decimal(resolved_shape.output_tokens) * candidate.output_cost_per_token
    return total
def compare_model_candidates(
    candidates: Sequence[ModelDefaultCandidate],
    *,
    shape: ModelCallShape | None = None,
    budget_usd: Decimal = Decimal("1"),
    sort_by: ModelComparisonSortName = "call_cost",
    notes: Iterable[str] | None = None,
) -> ModelCostComparison:
    """Build a cost comparison from preloaded catalog candidates.

    Args:
        candidates: Catalog rows to price.
        shape: Representative token shape. Defaults to ``ModelCallShape()``.
        budget_usd: Budget used for each estimate's ``calls_per_budget``.
        sort_by: Sort mode applied to the returned estimates.
        notes: Extra notes appended after the standard pricing caveats.

    Returns:
        A comparison holding one estimate per candidate, sorted by ``sort_by``.
    """
    resolved_shape = shape or ModelCallShape()
    estimates = [
        _estimate_from_candidate(
            candidate,
            shape=resolved_shape,
            budget_usd=budget_usd,
        )
        for candidate in candidates
    ]
    return ModelCostComparison(
        shape=resolved_shape,
        budget_usd=budget_usd,
        estimates=_sort_estimates(estimates, sort_by=sort_by),
        notes=[
            "Costs are estimates from catalog pricing for the configured token shape.",
            "Use provider response usage metadata as the source of truth for billing.",
            *(notes or []),
        ],
    )
def compare_model_catalog(
    settings: AppSettings | None = None,
    *,
    providers: Iterable[Provider | str] | None = None,
    source: ModelDefaultSource = "auto",
    config: ListModelsConfig | None = None,
    include_non_chat: bool = False,
    capabilities: Iterable[ModelCapabilityName] | None = None,
    min_context_tokens: int | None = None,
    min_output_tokens: int | None = None,
    max_input_cost_per_1m: Decimal | None = None,
    max_output_cost_per_1m: Decimal | None = None,
    released_after: str | None = None,
    released_before: str | None = None,
    shape: ModelCallShape | None = None,
    input_tokens: int | None = None,
    output_tokens: int | None = None,
    budget_usd: Decimal = Decimal("1"),
    per_provider: bool = False,
    sort_by: ModelComparisonSortName = "call_cost",
    limit: int | None = None,
    strict: bool = False,
) -> ModelCostComparison:
    """Compare model costs across a filtered model catalog.

    Args:
        settings: Base settings. Defaults to ``AppSettings()``.
        providers: Providers to inspect. Defaults to every supported provider.
        source: ``"provider"``, ``"litellm"``, or ``"auto"`` catalog source.
        config: Optional provider-listing configuration.
        include_non_chat: Include non-chat-like models.
        capabilities: Required capability labels, such as ``"coding"``.
        min_context_tokens: Optional minimum context/input-token window.
        min_output_tokens: Optional minimum output-token limit.
        max_input_cost_per_1m: Optional maximum input cost per one million tokens.
        max_output_cost_per_1m: Optional maximum output cost per one million tokens.
        released_after: Optional lower release-date bound.
        released_before: Optional upper release-date bound.
        shape: Representative input/output token shape.
        input_tokens: Override ``shape.input_tokens``.
        output_tokens: Override ``shape.output_tokens``.
        budget_usd: Budget used for calls-per-budget estimates. Defaults to $1.
        per_provider: Keep only the cheapest row (by call cost, whatever
            ``sort_by`` is) for each provider.
        sort_by: Sort mode for returned estimates. Defaults to call cost.
        limit: Optional maximum number of returned estimates. ``None``, ``0``,
            or a negative value disables it.
        strict: Raise on provider listing failure.

    Returns:
        Cost comparison sorted by the requested estimate field.
    """
    resolved_shape = _resolve_shape(shape, input_tokens=input_tokens, output_tokens=output_tokens)
    catalog = list_model_catalog(
        settings,
        providers=providers,
        source=source,
        config=config,
        include_non_chat=include_non_chat,
        capabilities=capabilities,
        min_context_tokens=min_context_tokens,
        min_output_tokens=min_output_tokens,
        max_input_cost_per_1m=max_input_cost_per_1m,
        max_output_cost_per_1m=max_output_cost_per_1m,
        released_after=released_after,
        released_before=released_before,
        sort_by=_catalog_sort_for_comparison(sort_by),
        strict=strict,
    )
    comparison = compare_model_candidates(
        catalog.models,
        shape=resolved_shape,
        budget_usd=budget_usd,
        sort_by=sort_by,
        notes=catalog.notes,
    )
    if per_provider:
        # cheapest_per_provider() orders by call cost; re-apply sort_by.
        estimates = _sort_estimates(comparison.cheapest_per_provider(), sort_by=sort_by)
    else:
        # compare_model_candidates already sorted these by sort_by.
        estimates = list(comparison.estimates)
    if limit is not None and limit > 0:
        estimates = estimates[:limit]
    return comparison.model_copy(update={"estimates": estimates})
def get_cheapest_models(
    settings: AppSettings | None = None,
    *,
    providers: Iterable[Provider | str] | None = None,
    source: ModelDefaultSource = "auto",
    per_provider: bool = False,
    limit: int | None = 20,
    **kwargs: Any,
) -> ModelCostComparison:
    """Return the cheapest catalog-priced models for a token shape.

    Thin wrapper over :func:`compare_model_catalog` with a default ``limit``
    of 20. Any other keyword (``input_tokens``, ``capabilities``,
    ``sort_by``, ...) is forwarded unchanged.
    """
    return compare_model_catalog(
        settings,
        providers=providers,
        source=source,
        per_provider=per_provider,
        limit=limit,
        **kwargs,
    )
def get_coding_model_comparison(
    settings: AppSettings | None = None,
    *,
    providers: Iterable[Provider | str] | None = None,
    source: ModelDefaultSource = "auto",
    limit: int | None = 20,
    **kwargs: Any,
) -> ModelCostComparison:
    """Return a cost-ranked comparison of coding-oriented catalog models.

    ``"coding"`` is always added to any ``capabilities`` passed in
    ``kwargs``. Every other keyword is forwarded to
    :func:`compare_model_catalog`.
    """
    requested = kwargs.pop("capabilities", None) or []
    # A lone capability name is a str; list() would split it into characters.
    capabilities = [requested] if isinstance(requested, str) else list(requested)
    if "coding" not in capabilities:
        capabilities.append("coding")
    return compare_model_catalog(
        settings,
        providers=providers,
        source=source,
        capabilities=capabilities,
        limit=limit,
        **kwargs,
    )
def _resolve_shape(
shape: ModelCallShape | None,
*,
input_tokens: int | None,
output_tokens: int | None,
) -> ModelCallShape:
resolved = shape or ModelCallShape()
updates: dict[str, int] = {}
if input_tokens is not None:
updates["input_tokens"] = input_tokens
if output_tokens is not None:
updates["output_tokens"] = output_tokens
if not updates:
return resolved
return ModelCallShape(
input_tokens=updates.get("input_tokens", resolved.input_tokens),
output_tokens=updates.get("output_tokens", resolved.output_tokens),
)
def _estimate_from_candidate(
    candidate: ModelDefaultCandidate,
    *,
    shape: ModelCallShape,
    budget_usd: Decimal,
) -> ModelCostEstimate:
    """Build one ``ModelCostEstimate`` row from a catalog candidate.

    Throughput fields (``calls_per_usd`` / ``calls_per_budget``) stay ``None``
    unless the call cost is known and strictly positive, which also avoids
    dividing by zero.
    """
    call_cost = estimate_model_call_cost(candidate, shape)
    has_positive_cost = call_cost is not None and call_cost > 0
    return ModelCostEstimate(
        provider=candidate.provider,
        model=candidate.model_string,
        source=candidate.source,
        release_date=candidate.release_date,
        input_cost_per_1m_tokens=candidate.input_cost_per_1m_tokens,
        output_cost_per_1m_tokens=candidate.output_cost_per_1m_tokens,
        # The candidate's context window is reused as the input-token limit.
        max_input_tokens=candidate.context_window,
        max_output_tokens=candidate.max_output_tokens,
        context_window=candidate.context_window,
        capabilities=candidate.capability_labels,
        call_cost_usd=call_cost,
        calls_per_usd=Decimal("1") / call_cost if has_positive_cost else None,
        calls_per_budget=budget_usd / call_cost if has_positive_cost else None,
        budget_usd=budget_usd,
    )
def _catalog_sort_for_comparison(sort_by: ModelComparisonSortName) -> ModelCatalogSortName:
if sort_by in {"input_tokens", "context"}:
return "input_tokens"
if sort_by == "output_tokens":
return "output_tokens"
if sort_by == "provider":
return "provider"
if sort_by == "model":
return "model"
return "cost"
def _estimate_sort_key(estimate: ModelCostEstimate) -> tuple[bool, Decimal, str, str]:
infinity = Decimal("Infinity")
return (
estimate.call_cost_usd is None,
estimate.call_cost_usd or infinity,
estimate.provider.value,
estimate.model.model_name,
)
def _sort_estimates(
estimates: Sequence[ModelCostEstimate],
*,
sort_by: ModelComparisonSortName,
) -> list[ModelCostEstimate]:
if sort_by in {"call_cost", "cost"}:
return sorted(estimates, key=_estimate_sort_key)
if sort_by in {"calls", "calls_per_usd"}:
return sorted(
estimates,
key=lambda estimate: (
estimate.calls_per_budget is None,
-(estimate.calls_per_budget or Decimal("0")),
estimate.provider.value,
estimate.model.model_name,
),
)
if sort_by == "provider":
return sorted(estimates, key=lambda estimate: (estimate.provider.value, estimate.model.model_name))
if sort_by == "model":
return sorted(estimates, key=lambda estimate: estimate.model.model_name)
if sort_by in {"input_tokens", "context"}:
return sorted(
estimates,
key=lambda estimate: (
estimate.context_window is None,
-(estimate.context_window or 0),
_estimate_sort_key(estimate),
),
)
if sort_by == "output_tokens":
return sorted(
estimates,
key=lambda estimate: (
estimate.max_output_tokens is None,
-(estimate.max_output_tokens or 0),
_estimate_sort_key(estimate),
),
)
return sorted(estimates, key=_estimate_sort_key)
def _model_match_keys(model: str | ModelString) -> set[str]:
    """Return the lowercase lookup keys under which ``model`` can be matched.

    Keys cover the canonical string, the LangChain and LiteLLM renderings,
    and the bare model name; empty renderings are dropped.
    """
    canonical = ModelString.parse(model).canonical()
    variants = (
        str(canonical),
        canonical.as_langchain(),
        canonical.as_litellm(),
        canonical.model_name,
    )
    return {variant.lower() for variant in variants if variant}