"""Serializable chat-model profiles and runtime wrapper.
Purpose:
Provide a durable Pydantic profile that can be checked into configuration
and used to create LangChain chat models through the existing factory
surface. The ``LLM`` runtime owns one profile, lazily builds the runnable,
and records observed usage/cost metadata after calls.
"""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from decimal import Decimal
import hashlib
import json
from typing import Any, Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, Field
from .cache import build_namespaced_cache
from .callbacks import (
BudgetPolicy,
UsageEvent,
UsageRecorder,
record_langchain_response_usage,
)
from .factory import create_llm as factory_create_llm
from .factory import create_llm_bundle, resolve_factory_settings, resolve_model_string
from .logging import get_logger, log_event, logging_context
from .metadata import CreatedLLMBundle, ModelInfo, get_model_info
from .providers import Provider, normalize_provider_name
from .reasoning import ReasoningConfig, ReasoningInput, ReasoningResolution, build_reasoning_resolution
from .settings import AppSettings, ModelAliasName, ModelPresetName
from .types import ModelString
# Cache backends a ``CacheKeyPolicy`` may select; mirrors the ``backend``
# values written into ``settings.llm.cache`` by ``ChatModelProfile``.
CacheBackendName = Literal["sqlite", "memory", "sqlalchemy", "redis", "upstash_redis"]
# Module logger used for the structured ``llm.*`` lifecycle events below.
logger = get_logger(__name__)
class CacheKeyPolicy(BaseModel):
    """Serializable cache policy for a chat-model profile.

    ``ChatModelProfile`` reads every field here:

    * ``namespace`` is passed to ``build_namespaced_cache`` and reported by
      ``ChatModelProfile.resolve``.
    * ``backend`` optionally overrides the backend in ``settings.llm.cache``.
    * ``enabled`` set to ``False`` turns the policy into ``cache=False``.
    * ``key`` pins an explicit cache key instead of the derived profile hash.
    """

    model_config = ConfigDict(extra="forbid")

    # Logical partition for cached responses.
    namespace: str = "default"
    # Backend override; ``None`` keeps whatever the settings configure.
    backend: CacheBackendName | None = None
    # A policy object implies caching is wanted unless explicitly disabled.
    enabled: bool = True
    # Explicit cache key; when unset, ``ChatModelProfile.cache_key`` derives one.
    key: str | None = None
class ChatModelProfileResolution(BaseModel):
    """Resolved model/profile metadata without constructing a chat model.

    Produced by ``ChatModelProfile.resolve``. The cache fields are only
    populated when the profile carries an enabled ``CacheKeyPolicy``.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Model identifier returned by ``resolve_model_string``.
    model: ModelString | None = None
    # ``ChatModelProfile.id`` of the profile that was resolved.
    profile_id: str | None = None
    # Result of ``get_model_info`` for the resolved model.
    metadata: ModelInfo | None = None
    # Result of ``build_reasoning_resolution`` for the resolved model.
    reasoning: ReasoningResolution | None = None
    # Namespace of the profile's cache policy, when caching is enabled.
    cache_namespace: str | None = None
    # Explicit or derived profile cache key, when caching is enabled.
    cache_key: str | None = None
class ChatModelProfile(BaseModel):
    """Serializable configuration for creating a LangChain chat model.

    A profile is plain data. It round-trips through JSON (``to_json`` /
    ``from_json``) and can be turned into a chat model (``create_llm``), a
    model-plus-metadata bundle (``create_bundle``), a stateful ``LLM``
    runtime (``create_runtime``) or a DSPy LM (``create_dspy_lm``).
    Sampling fields left as ``None`` are not forwarded to the constructor.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # Profile identifier. It is logged, added to runtime metadata as
    # ``ooai_profile_id``, and used as the default ``LLM.id``.
    id: str | None = None
    # Model selection inputs handed to the factory resolver.
    model: str | None = None
    alias: ModelAliasName | None = None
    provider: str | Provider | None = None
    preset: ModelPresetName = "default"
    # Request parameters; each is forwarded only when not ``None``.
    temperature: float | None = None
    max_tokens: int | None = Field(default=None, ge=1)
    top_p: float | None = Field(default=None, ge=0, le=1)
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    max_retries: int | None = Field(default=None, ge=0)
    timeout: float | None = Field(default=None, ge=0)
    streaming: bool | None = None
    # Merged into ``model_kwargs["parallel_tool_calls"]`` when set.
    parallel_tool_calls: bool | None = None
    configurable_fields: str | list[str] | None = None
    config_prefix: str | None = None
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    constructor_kwargs: dict[str, Any] = Field(default_factory=dict)
    reasoning: ReasoningInput = None
    auto_refresh_models: bool | None = None
    force_model_refresh: bool = False
    billing_model_name: str | None = None
    cache: CacheKeyPolicy | bool | None = None
    # Observability fields merged into the LangChain runtime config by ``LLM``.
    run_name: str | None = None
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)

    def to_json(self, *, indent: int = 2) -> str:
        """Serialize the profile to JSON, omitting fields that are ``None``."""
        return self.model_dump_json(indent=indent, exclude_none=True)

    @classmethod
    def from_json(cls, text: str) -> ChatModelProfile:
        """Deserialize and validate a profile from JSON text."""
        return cls.model_validate_json(text)

    def create_llm(
        self,
        *,
        settings: AppSettings | None = None,
        **kwargs: Any,
    ) -> Any:
        """Create a LangChain chat model from this profile.

        Args:
            settings: Application settings. A default ``AppSettings()`` is used
                when omitted.
            **kwargs: Extra constructor kwargs. They override profile-derived
                values (see ``_constructor_kwargs``).

        Returns:
            The object returned by the factory's ``create_llm``.
        """
        resolved_settings = settings if settings is not None else AppSettings()
        return factory_create_llm(
            model=self.model,
            settings=resolved_settings,
            alias=self.alias,
            provider=self.provider,
            preset=self.preset,
            cache=self._cache_argument(resolved_settings),
            reasoning=self.reasoning,
            auto_refresh_models=self.auto_refresh_models,
            force_model_refresh=self.force_model_refresh,
            configurable_fields=self.configurable_fields,
            config_prefix=self.config_prefix,
            **self._constructor_kwargs(kwargs),
        )

    def create_bundle(
        self,
        *,
        settings: AppSettings | None = None,
        messages: Any = None,
        tools: Sequence[Any] | None = None,
        **kwargs: Any,
    ) -> CreatedLLMBundle:
        """Create a chat model plus resolved metadata from this profile.

        Args:
            settings: Application settings. A default ``AppSettings()`` is used
                when omitted.
            messages: Optional messages forwarded to ``create_llm_bundle``.
            tools: Optional tools forwarded to ``create_llm_bundle``.
            **kwargs: Extra constructor kwargs that override profile values.
        """
        resolved_settings = settings if settings is not None else AppSettings()
        return create_llm_bundle(
            model=self.model,
            settings=resolved_settings,
            alias=self.alias,
            provider=self.provider,
            preset=self.preset,
            cache=self._cache_argument(resolved_settings),
            reasoning=self.reasoning,
            auto_refresh_models=self.auto_refresh_models,
            force_model_refresh=self.force_model_refresh,
            billing_model_name=self.billing_model_name,
            messages=messages,
            tools=tools,
            configurable_fields=self.configurable_fields,
            config_prefix=self.config_prefix,
            **self._constructor_kwargs(kwargs),
        )

    def create_runtime(
        self,
        *,
        settings: AppSettings | None = None,
        recorder: UsageRecorder | None = None,
        budget: BudgetPolicy | None = None,
        id: str | None = None,
        uuid: str | UUID | None = None,
    ) -> LLM:
        """Create an ``LLM`` runtime that owns this profile."""
        return LLM(profile=self, settings=settings, recorder=recorder, budget=budget, id=id, uuid=uuid)

    def create_dspy_lm(
        self,
        *,
        settings: AppSettings | None = None,
        **kwargs: Any,
    ) -> Any:
        """Create a DSPy ``LM`` from this profile.

        DSPy is an optional dependency and is imported lazily. Install
        ``ooai-llm[dspy]`` before using this helper.
        """
        from .dspy import create_dspy_lm

        return create_dspy_lm(self, settings=settings, **kwargs)

    def resolve(
        self,
        *,
        settings: AppSettings | None = None,
        auto_refresh_models: bool | None = None,
        force_model_refresh: bool | None = None,
    ) -> ChatModelProfileResolution:
        """Resolve model, metadata, reasoning, and cache key material.

        No chat model is constructed. The refresh arguments override the
        profile's own flags when given.
        """
        resolved_settings = resolve_factory_settings(
            settings,
            auto_refresh_models=self.auto_refresh_models if auto_refresh_models is None else auto_refresh_models,
            force_model_refresh=self.force_model_refresh if force_model_refresh is None else force_model_refresh,
        )
        resolved_model = resolve_model_string(
            settings=resolved_settings,
            model=self.model,
            alias=self.alias,
            provider=self.provider,
            preset=self.preset,
        )
        metadata = get_model_info(
            model=resolved_model,
            settings=resolved_settings,
            provider=self.provider,
            billing_model_name=self.billing_model_name,
        )
        reasoning = build_reasoning_resolution(
            model=resolved_model,
            provider=self.provider,
            reasoning=self.reasoning,
        )
        # Cache details are only reported for an explicit, enabled policy;
        # ``cache=True`` or ``False`` carries no namespace.
        active_policy = self.cache if isinstance(self.cache, CacheKeyPolicy) and self.cache.enabled else None
        return ChatModelProfileResolution(
            model=resolved_model,
            profile_id=self.id,
            metadata=metadata,
            reasoning=reasoning,
            cache_namespace=active_policy.namespace if active_policy else None,
            cache_key=self.cache_key() if active_policy else None,
        )

    def cache_key(self) -> str:
        """Return the explicit or derived cache key for this profile.

        An explicit ``CacheKeyPolicy.key`` wins. Otherwise the key is the
        first 16 hex chars of a SHA-256 over the sorted-key JSON dump of every
        non-``None`` field except ``cache``, plus the policy's namespace and
        backend.

        NOTE(review): every dumped field feeds the hash, including
        ``force_model_refresh``, ``run_name``, ``tags``, ``metadata`` and
        ``cost_labels``. Toggling any of them moves the profile to a
        different cache partition. Confirm this is intended.
        """
        if isinstance(self.cache, CacheKeyPolicy) and self.cache.key:
            return self.cache.key
        payload = self.model_dump(mode="json", exclude={"cache"}, exclude_none=True)
        if isinstance(self.cache, CacheKeyPolicy):
            payload["cache_namespace"] = self.cache.namespace
            payload["cache_backend"] = self.cache.backend
        digest = hashlib.sha256(json.dumps(payload, sort_keys=True, default=str).encode("utf-8")).hexdigest()
        return digest[:16]

    def _constructor_kwargs(self, overrides: Mapping[str, Any] | None = None) -> dict[str, Any]:
        """Build constructor kwargs, with top-level profile fields taking precedence.

        Precedence, from lowest to highest:

        1. ``constructor_kwargs``.
        2. ``model_kwargs`` and ``parallel_tool_calls``, merged into any
           existing ``model_kwargs`` mapping.
        3. Non-``None`` top-level sampling fields.
        4. Caller ``overrides``.
        """
        kwargs = dict(self.constructor_kwargs)

        model_kwarg_updates = dict(self.model_kwargs)
        if self.parallel_tool_calls is not None:
            model_kwarg_updates["parallel_tool_calls"] = self.parallel_tool_calls
        if model_kwarg_updates:
            existing = kwargs.get("model_kwargs")
            # A non-mapping ``model_kwargs`` inside ``constructor_kwargs`` is discarded.
            merged = dict(existing) if isinstance(existing, Mapping) else {}
            merged.update(model_kwarg_updates)
            kwargs["model_kwargs"] = merged

        for field_name in (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
            "stop",
            "max_retries",
            "timeout",
            "streaming",
        ):
            value = getattr(self, field_name)
            if value is not None:
                kwargs[field_name] = value

        if overrides:
            kwargs.update(overrides)
        return kwargs

    def _cache_argument(self, settings: AppSettings) -> Any:
        """Translate ``self.cache`` into the factory's ``cache`` argument.

        ``None`` and booleans pass through unchanged, and a disabled policy
        becomes ``False``. An enabled policy builds a namespaced cache with
        ``build_namespaced_cache``. If the policy names a backend, the
        settings' cache backend is overridden and forced enabled first.
        """
        policy = self.cache
        if policy is None or isinstance(policy, bool):
            return policy
        if not policy.enabled:
            return False
        cache_settings = settings
        if policy.backend is not None:
            llm_cache = settings.llm.cache.model_copy(update={"backend": policy.backend, "enabled": True})
            llm_settings = settings.llm.model_copy(update={"cache": llm_cache})
            cache_settings = settings.model_copy(update={"llm": llm_settings})
        return build_namespaced_cache(
            cache_settings,
            namespace=policy.namespace,
            profile_key=self.cache_key(),
        )
class LLM:
    """Stateful runtime wrapper around a serializable ``ChatModelProfile``.

    The LangChain chat model is built lazily on first use and cached until
    the profile is replaced (``set_profile``) or the runtime is invalidated.
    Every call receives a config from ``_runtime_config``, which adds the
    profile's run name, tags and metadata plus the runtime identifiers.
    Observed usage is recorded into ``recorder``, with ``budget`` forwarded
    to the recorder.
    """

    def __init__(
        self,
        *,
        profile: ChatModelProfile | Mapping[str, Any],
        settings: AppSettings | None = None,
        recorder: UsageRecorder | None = None,
        budget: BudgetPolicy | None = None,
        id: str | None = None,
        uuid: str | UUID | None = None,
    ) -> None:
        """Create a runtime for ``profile``.

        Args:
            profile: A ``ChatModelProfile`` or a mapping validated into one.
            settings: Application settings. A default ``AppSettings()`` is
                used when omitted.
            recorder: Usage sink. A fresh ``UsageRecorder`` is created when
                omitted.
            budget: Optional budget policy forwarded with each usage record.
            id: Runtime id. When omitted it is derived from the profile id
                (or the UUID) and re-derived by ``set_profile``.
            uuid: Runtime UUID as ``str`` or ``UUID``. Generated when omitted.
        """
        self.profile = profile if isinstance(profile, ChatModelProfile) else ChatModelProfile.model_validate(profile)
        # ``is not None`` so a caller-supplied (possibly falsy) object is never replaced.
        self.settings = settings if settings is not None else AppSettings()
        self.recorder = recorder if recorder is not None else UsageRecorder()
        self.budget = budget
        self.uuid = _coerce_uuid(uuid)
        # Remembers whether ``id`` should follow the profile when it changes.
        self._id_is_profile_derived = id is None
        self.id = id or self._default_runtime_id()
        self._bundle: CreatedLLMBundle | None = None
        self._runnable: Any | None = None
        self._metadata: ModelInfo | None = None
        log_event(
            logger,
            "llm.runtime.created",
            llm_id=self.id,
            llm_uuid=str(self.uuid),
            profile_id=self.profile.id,
        )

    @property
    def bundle(self) -> CreatedLLMBundle:
        """Return the bundle, building it from the profile on first access."""
        if self._bundle is None:
            log_event(
                logger,
                "llm.bundle.build",
                llm_id=self.id,
                llm_uuid=str(self.uuid),
                profile_id=self.profile.id,
                model=self.profile.model,
                alias=self.profile.alias,
                provider=str(self.profile.provider) if self.profile.provider is not None else None,
            )
            self._bundle = self.profile.create_bundle(settings=self.settings)
            self._runnable = self._bundle.llm
            log_event(
                logger,
                "llm.bundle.ready",
                llm_id=self.id,
                llm_uuid=str(self.uuid),
                model=self._bundle.model.as_langchain(),
            )
        return self._bundle

    @property
    def runnable(self) -> Any:
        """Return the lazily built LangChain runnable."""
        return self.bundle.llm

    @property
    def model(self) -> ModelString:
        """Return the resolved model string."""
        return self.bundle.model

    @property
    def metadata(self) -> ModelInfo:
        """Return ``get_model_info`` output for the resolved model, cached until invalidation."""
        if self._metadata is None:
            self._metadata = get_model_info(
                model=self.model,
                settings=self.settings,
                provider=self.profile.provider,
                billing_model_name=self.profile.billing_model_name,
            )
        return self._metadata

    @property
    def usage_events(self) -> list[UsageEvent]:
        """Return the usage events collected by ``recorder``."""
        return self.recorder.events

    @property
    def total_tokens(self) -> int:
        """Return total observed tokens."""
        return self.recorder.total_tokens

    @property
    def total_cost_usd(self) -> Decimal:
        """Return total observed/estimated cost."""
        return self.recorder.total_cost_usd

    @property
    def usage_summary(self) -> Any:
        """Return grouped observed usage totals."""
        return self.recorder.summary()

    def set_profile(self, profile: ChatModelProfile | Mapping[str, Any], *, id: str | None = None) -> None:
        """Replace the profile and invalidate the cached runnable.

        An explicit ``id`` pins the runtime id. Otherwise a profile-derived
        id is recomputed from the new profile.
        """
        self.profile = profile if isinstance(profile, ChatModelProfile) else ChatModelProfile.model_validate(profile)
        if id is not None:
            self.id = id
            self._id_is_profile_derived = False
        elif self._id_is_profile_derived:
            self.id = self._default_runtime_id()
        log_event(
            logger,
            "llm.profile.updated",
            llm_id=self.id,
            llm_uuid=str(self.uuid),
            profile_id=self.profile.id,
        )
        self.invalidate()

    def invalidate(self) -> None:
        """Drop the cached runnable, bundle, and model metadata."""
        self._bundle = None
        self._runnable = None
        self._metadata = None
        log_event(logger, "llm.runtime.invalidated", llm_id=self.id, llm_uuid=str(self.uuid))

    def refresh(self, *, force: bool = False) -> CreatedLLMBundle:
        """Rebuild the runnable, optionally forcing model-default refresh.

        With ``force=True`` the bundle is built from a copy of the profile
        with ``force_model_refresh=True``. The original profile is restored
        afterwards, even if the build fails.
        """
        self.invalidate()
        if force:
            original = self.profile
            # NOTE(review): ``cache_key`` hashes ``force_model_refresh``, so a
            # derived-key cache built here lands in a different partition.
            self.profile = self.profile.model_copy(update={"force_model_refresh": True})
            try:
                return self.bundle
            finally:
                self.profile = original
        return self.bundle

    def create_dspy_lm(self, **kwargs: Any) -> Any:
        """Create a DSPy ``LM`` from this runtime's profile and settings.

        The returned object is a native DSPy LM. The richer program/runnable
        layer belongs in ``ooai-agents``.
        """
        from .dspy import create_dspy_lm

        return create_dspy_lm(self, **kwargs)

    def invoke(self, input: Any, config: Mapping[str, Any] | None = None, **kwargs: Any) -> Any:
        """Invoke the underlying chat model and record observed usage."""
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(logger, "llm.invoke", llm_id=self.id, llm_uuid=str(self.uuid), method="invoke")
            result = self.runnable.invoke(input, config=self._runtime_config(config), **kwargs)
            self._record_response_usage(result)
            return result

    async def ainvoke(self, input: Any, config: Mapping[str, Any] | None = None, **kwargs: Any) -> Any:
        """Async invoke the underlying chat model and record observed usage."""
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(logger, "llm.invoke", llm_id=self.id, llm_uuid=str(self.uuid), method="ainvoke")
            result = await self.runnable.ainvoke(input, config=self._runtime_config(config), **kwargs)
            self._record_response_usage(result)
            return result

    def batch(
        self,
        inputs: Sequence[Any],
        config: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> list[Any]:
        """Batch invoke the underlying chat model and record observed usage.

        Runnables without ``batch`` fall back to one ``invoke`` per input.
        """
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(
                logger,
                "llm.invoke",
                llm_id=self.id,
                llm_uuid=str(self.uuid),
                method="batch",
                input_count=len(inputs),
            )
            if not hasattr(self.runnable, "batch"):
                return [self.invoke(item, config=config, **kwargs) for item in inputs]
            # Materialize once so recording and returning see the same items
            # even if the runnable hands back a one-shot iterator.
            results = list(self.runnable.batch(inputs, config=self._runtime_config(config), **kwargs))
            for result in results:
                self._record_response_usage(result)
            return results

    async def abatch(
        self,
        inputs: Sequence[Any],
        config: Mapping[str, Any] | None = None,
        **kwargs: Any,
    ) -> list[Any]:
        """Async batch invoke the underlying chat model and record observed usage.

        Runnables without ``abatch`` fall back to sequential ``ainvoke`` calls.
        """
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(
                logger,
                "llm.invoke",
                llm_id=self.id,
                llm_uuid=str(self.uuid),
                method="abatch",
                input_count=len(inputs),
            )
            if not hasattr(self.runnable, "abatch"):
                return [await self.ainvoke(item, config=config, **kwargs) for item in inputs]
            results = list(await self.runnable.abatch(inputs, config=self._runtime_config(config), **kwargs))
            for result in results:
                self._record_response_usage(result)
            return results

    def stream(self, input: Any, config: Mapping[str, Any] | None = None, **kwargs: Any):
        """Stream chunks from the underlying chat model, then record usage.

        Chunks are yielded unchanged and folded together by
        ``_accumulate_chunk``. Usage is recorded once the stream is exhausted;
        nothing is recorded if the consumer stops early.
        """
        aggregate = None
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(logger, "llm.invoke", llm_id=self.id, llm_uuid=str(self.uuid), method="stream")
            for chunk in self.runnable.stream(input, config=self._runtime_config(config), **kwargs):
                aggregate = self._accumulate_chunk(aggregate, chunk)
                yield chunk
            if aggregate is not None:
                self._record_response_usage(aggregate)

    async def astream(self, input: Any, config: Mapping[str, Any] | None = None, **kwargs: Any):
        """Async stream chunks from the underlying chat model, then record usage.

        Uses the same chunk aggregation and recording rules as ``stream``.
        """
        aggregate = None
        with logging_context(llm_id=self.id, llm_uuid=str(self.uuid)):
            log_event(logger, "llm.invoke", llm_id=self.id, llm_uuid=str(self.uuid), method="astream")
            async for chunk in self.runnable.astream(input, config=self._runtime_config(config), **kwargs):
                aggregate = self._accumulate_chunk(aggregate, chunk)
                yield chunk
            if aggregate is not None:
                self._record_response_usage(aggregate)

    @staticmethod
    def _accumulate_chunk(aggregate: Any, chunk: Any) -> Any:
        """Fold ``chunk`` into the running stream aggregate.

        LangChain message chunks support ``+``, which merges content and
        usage metadata. This catches usage that providers split across
        several chunks. Chunks that cannot be added fall back to keeping only
        the most recent chunk.
        """
        if aggregate is None:
            return chunk
        try:
            return aggregate + chunk
        except (TypeError, ValueError):
            # Usage accounting is best-effort; never break the stream over it.
            return chunk

    def _default_runtime_id(self) -> str:
        """Return the profile id, or ``llm-<uuid>`` when the profile has none."""
        return self.profile.id or f"llm-{self.uuid}"

    def _identity_metadata(self) -> dict[str, str]:
        """Return the runtime/profile identifiers attached to every call and usage record."""
        identity = {
            "ooai_llm_id": self.id,
            "ooai_llm_uuid": str(self.uuid),
        }
        if self.profile.id is not None:
            identity["ooai_profile_id"] = self.profile.id
        return identity

    def _runtime_config(self, config: Mapping[str, Any] | None = None) -> dict[str, Any]:
        """Merge profile tags/metadata/run name with caller runtime config.

        Caller values win over profile values. The exceptions are tags, which
        are concatenated (profile first), and the runtime identifiers, which
        always overwrite metadata keys of the same name.
        """
        merged = dict(config or {})
        if self.profile.run_name and "run_name" not in merged:
            merged["run_name"] = self.profile.run_name
        profile_tags = list(self.profile.tags)
        caller_tags = list(merged.get("tags") or [])
        if profile_tags or caller_tags:
            merged["tags"] = [*profile_tags, *caller_tags]
        profile_metadata = dict(self.profile.metadata)
        if self.profile.cost_labels:
            profile_metadata.setdefault("ooai_cost_labels", dict(self.profile.cost_labels))
        caller_metadata = dict(merged.get("metadata") or {})
        merged["metadata"] = {**profile_metadata, **caller_metadata, **self._identity_metadata()}
        return merged

    def _record_response_usage(self, response: Any) -> None:
        """Record response usage via ``record_langchain_response_usage`` and log it when recorded."""
        event = record_langchain_response_usage(
            self.recorder,
            response=response,
            model=self.model,
            budget=self.budget,
            settings=self.settings,
            profile=self.metadata.capabilities.raw_profile,
            count_source="provider_usage_metadata",
            run_name=self.profile.run_name,
            tags=self.profile.tags,
            metadata={**self.profile.metadata, **self._identity_metadata()},
            cost_labels=self.profile.cost_labels,
        )
        if event is not None:
            log_event(
                logger,
                "llm.usage.recorded",
                llm_id=self.id,
                llm_uuid=str(self.uuid),
                model=event.model.as_langchain(),
                total_tokens=event.total_tokens,
                cost_usd=str(event.cost_usd) if event.cost_usd is not None else None,
            )
def _coerce_uuid(value: str | UUID | None) -> UUID:
"""Return a UUID from a caller value or create a new one."""
if value is None:
return uuid4()
if isinstance(value, UUID):
return value
return UUID(str(value))