Source code for ooai_llm.callbacks

"""Usage and cost callback helpers.

Purpose:
    Provide ergonomic callback helpers that work with LangChain usage metadata
    and the native LiteLLM callback interface.

Design:
    - Normalize usage and cost into a shared ``UsageEvent`` model.
    - Offer a recorder object that can accumulate usage across many calls.
    - Expose a LiteLLM-compatible success callback factory for cost and usage
      tracking.
    - Keep callback helpers transport-agnostic so they can support chat,
      embeddings, and future model families.

Examples:
    >>> recorder = UsageRecorder()
    >>> callback = make_litellm_cost_callback(recorder)
    >>> callable(callback)
    True
"""

from __future__ import annotations

from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, computed_field

from .logging import get_logger, log_event
from .metadata import build_usage_snapshot, calculate_cost, get_model_info
from .types import ModelString

# Closed set of provenance labels describing where an event's token counts
# came from. ``UsageEvent.count_source`` is typed with this alias.
CountSource = Literal[
    "provider_preflight",
    "provider_usage_metadata",
    "dspy_usage_metadata",
    "local_tokenizer",
    "framework_callback",
    "approximation",
]
# Module-level logger, obtained from the package logging helper and keyed by
# this module's dotted name.
logger = get_logger(__name__)
class BudgetExceededError(RuntimeError):
    """Signal that a hard usage or cost threshold was crossed.

    ``BudgetPolicy.check`` raises this when an event's total tokens or cost is
    strictly greater than the configured ``error_*`` limit.
    """
class UsageEvent(BaseModel):
    """One normalized usage-and-cost record.

    Unknown fields are rejected (``extra="forbid"``).

    Attributes:
        source: Origin of the event, such as ``langchain`` or ``litellm``.
        model: Typed model string.
        input_tokens: Input token count; defaults to ``0``.
        output_tokens: Output token count; defaults to ``0``.
        total_tokens: Total token count; defaults to ``0``.
        cost_usd: Actual or estimated USD cost, or ``None`` when unknown.
        latency_ms: Measured latency in milliseconds when available.
        count_source: Provenance for token counts; defaults to
            ``"framework_callback"``.
        run_name: Optional LangChain run name or logical application run.
        tags: Optional framework/application tags.
        metadata: Optional framework/application metadata.
        cost_labels: Optional normalized labels used for cost grouping.
        raw: Original payload fragments.
    """

    model_config = ConfigDict(extra="forbid")

    # Identity of the call.
    source: str
    model: ModelString

    # Token accounting.
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0

    # Cost and timing; ``None`` means "not measured / not known".
    cost_usd: Decimal | None = None
    latency_ms: Decimal | None = None

    # Where the token numbers came from.
    count_source: CountSource = "framework_callback"

    # Optional grouping and debugging context. Mutable defaults go through
    # ``default_factory`` so instances never share containers.
    run_name: str | None = None
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)
    raw: dict[str, Any] = Field(default_factory=dict)
class UsageSummary(BaseModel):
    """Aggregate totals over a collection of recorded usage events.

    Unknown fields are rejected (``extra="forbid"``).

    Attributes:
        event_count: Number of events summarized.
        input_tokens: Sum of input tokens.
        output_tokens: Sum of output tokens.
        total_tokens: Sum of total tokens.
        total_cost_usd: Sum of known costs in USD; starts at ``Decimal("0")``.
        by_model: Integer totals keyed by model name.
        by_provider: Integer totals keyed by provider name.
        by_run: Integer totals keyed by run name.
    """

    model_config = ConfigDict(extra="forbid")

    # Overall totals.
    event_count: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    total_cost_usd: Decimal = Decimal("0")

    # Per-dimension breakdowns.
    by_model: dict[str, int] = Field(default_factory=dict)
    by_provider: dict[str, int] = Field(default_factory=dict)
    by_run: dict[str, int] = Field(default_factory=dict)
class BudgetPolicy(BaseModel):
    """Per-event warning and hard-stop thresholds for usage tracking.

    Every threshold is optional; a threshold left as ``None`` is not checked.
    Comparisons are strict (``>``), so a value equal to a threshold passes.

    Attributes:
        warn_cost_usd: Optional single-event cost warning threshold.
        error_cost_usd: Optional single-event hard cost threshold.
        warn_total_tokens: Optional single-event total-token warning threshold.
        error_total_tokens: Optional single-event hard token threshold.
    """

    model_config = ConfigDict(extra="forbid")

    warn_cost_usd: Decimal | None = None
    error_cost_usd: Decimal | None = None
    warn_total_tokens: int | None = None
    error_total_tokens: int | None = None

    def check(self, event: UsageEvent) -> list[str]:
        """Evaluate one event against the configured thresholds.

        Hard limits are evaluated before warning limits, tokens before cost.
        Cost thresholds are skipped when the event carries no cost.

        Args:
            event: Usage event to evaluate.

        Returns:
            Warning messages, possibly empty.

        Raises:
            BudgetExceededError: If a hard threshold is exceeded.
        """
        cost_known = event.cost_usd is not None

        # Hard limits: the first one crossed aborts the check.
        if self.error_total_tokens is not None:
            if event.total_tokens > self.error_total_tokens:
                raise BudgetExceededError(
                    f"Total token budget exceeded: {event.total_tokens} > {self.error_total_tokens}."
                )
        if cost_known and self.error_cost_usd is not None:
            if event.cost_usd > self.error_cost_usd:
                raise BudgetExceededError(
                    f"Cost budget exceeded: {event.cost_usd} > {self.error_cost_usd}."
                )

        # Soft limits: collect every message that applies.
        messages: list[str] = []
        if self.warn_total_tokens is not None:
            if event.total_tokens > self.warn_total_tokens:
                messages.append(
                    f"Total token warning threshold exceeded: {event.total_tokens} > {self.warn_total_tokens}."
                )
        if cost_known and self.warn_cost_usd is not None:
            if event.cost_usd > self.warn_cost_usd:
                messages.append(
                    f"Cost warning threshold exceeded: {event.cost_usd} > {self.warn_cost_usd}."
                )
        return messages
class UsageRecorder(BaseModel):
    """Accumulate normalized usage events in memory.

    Attributes:
        events: Recorded events, in the order they were recorded.
        warnings: Budget warning messages collected while recording.
    """

    model_config = ConfigDict(extra="forbid")

    events: list[UsageEvent] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)

    def record(self, event: UsageEvent, *, budget: BudgetPolicy | None = None) -> UsageEvent:
        """Store an event, after applying the budget policy if one is given.

        The budget check runs first, so an event that trips a hard threshold
        raises before it is stored or logged.

        Args:
            event: Event to store.
            budget: Optional policy whose warnings are appended to ``warnings``.

        Returns:
            The same ``event`` instance.
        """
        if budget is not None:
            budget_messages = budget.check(event)
            self.warnings.extend(budget_messages)

        self.events.append(event)

        cost_text = None if event.cost_usd is None else str(event.cost_usd)
        log_event(
            logger,
            "usage.recorded",
            source=event.source,
            model=event.model.as_langchain(),
            total_tokens=event.total_tokens,
            count_source=event.count_source,
            cost_usd=cost_text,
        )
        return event

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_tokens(self) -> int:
        """Return total tokens recorded so far."""
        return sum(recorded.total_tokens for recorded in self.events)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_cost_usd(self) -> Decimal:
        """Return total cost recorded so far, ignoring events without a cost."""
        known_costs = (
            recorded.cost_usd for recorded in self.events if recorded.cost_usd is not None
        )
        return sum(known_costs, Decimal("0"))

    @computed_field  # type: ignore[prop-decorator]
    @property
    def input_tokens(self) -> int:
        """Return input tokens recorded so far."""
        return sum(recorded.input_tokens for recorded in self.events)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def output_tokens(self) -> int:
        """Return output tokens recorded so far."""
        return sum(recorded.output_tokens for recorded in self.events)

    def summary(self) -> UsageSummary:
        """Return an aggregate usage summary grouped by model, provider, and run.

        The per-group values are sums of ``total_tokens``. Events whose model
        has no provider, or that have no run name, are left out of the
        corresponding grouping.
        """
        tokens_by_model: dict[str, int] = {}
        tokens_by_provider: dict[str, int] = {}
        tokens_by_run: dict[str, int] = {}

        def _add(bucket: dict[str, int], key: str, amount: int) -> None:
            bucket[key] = bucket.get(key, 0) + amount

        for recorded in self.events:
            amount = recorded.total_tokens
            _add(tokens_by_model, recorded.model.as_langchain(), amount)
            if recorded.model.provider is not None:
                _add(tokens_by_provider, recorded.model.provider.value, amount)
            if recorded.run_name:
                _add(tokens_by_run, recorded.run_name, amount)

        return UsageSummary(
            event_count=len(self.events),
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            total_tokens=self.total_tokens,
            total_cost_usd=self.total_cost_usd,
            by_model=tokens_by_model,
            by_provider=tokens_by_provider,
            by_run=tokens_by_run,
        )
def build_langchain_usage_event(
    *,
    model: str | ModelString,
    usage_metadata: Mapping[str, Any] | None,
    cost_usd: Decimal | None = None,
    latency_ms: Decimal | None = None,
    count_source: CountSource = "framework_callback",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent:
    """Build a normalized ``langchain`` event from LangChain usage metadata.

    Args:
        model: Raw or typed model string; parsed and canonicalized.
        usage_metadata: LangChain usage metadata, or ``None``.
        cost_usd: Optional cost to attach to the event.
        latency_ms: Optional latency to attach to the event.
        count_source: Provenance label for the token counts.
        run_name: Optional run name.
        tags: Optional tags; copied into a new list.
        metadata: Optional metadata; copied into a new dict.
        cost_labels: Optional cost-grouping labels; copied into a new dict.

    Returns:
        The constructed event. Nothing is recorded.
    """
    canonical_model = ModelString.parse(model).canonical()
    snapshot = build_usage_snapshot(usage_metadata)

    # Copy caller-supplied containers so the event never aliases them.
    tag_list = list(tags) if tags else []
    metadata_copy = dict(metadata) if metadata else {}
    label_copy = dict(cost_labels) if cost_labels else {}

    return UsageEvent(
        source="langchain",
        model=canonical_model,
        input_tokens=snapshot.input_tokens,
        output_tokens=snapshot.output_tokens,
        total_tokens=snapshot.resolved_total_tokens,
        cost_usd=cost_usd,
        latency_ms=latency_ms,
        count_source=count_source,
        run_name=run_name,
        tags=tag_list,
        metadata=metadata_copy,
        cost_labels=label_copy,
        raw=snapshot.raw_usage,
    )
def make_litellm_cost_callback(
    recorder: UsageRecorder,
    *,
    budget: BudgetPolicy | None = None,
) -> Any:
    """Return a LiteLLM success callback that records cost and usage.

    Args:
        recorder: Recorder instance to update.
        budget: Optional budget policy to evaluate for each event.

    Returns:
        LiteLLM-compatible callback function taking
        ``(kwargs, completion_response, start_time, end_time)``. Each call
        builds one ``litellm`` usage event and hands it to ``recorder.record``,
        so a ``BudgetExceededError`` raised by the budget check propagates out
        of the callback.
    """

    def _callback(
        kwargs: Mapping[str, Any],
        completion_response: Any,
        start_time: datetime | float | int,
        end_time: datetime | float | int,
    ) -> None:
        response_is_mapping = isinstance(completion_response, Mapping)

        # Model name: request kwargs win, then the response. The response may
        # be an object (attribute access) or a plain mapping (key access), the
        # same two shapes the usage lookup below accepts. An empty or ``None``
        # model falls back to "unknown" rather than being stringified.
        model_name = kwargs.get("model") or getattr(completion_response, "model", None)
        if not model_name and response_is_mapping:
            model_name = completion_response.get("model")
        model_name = str(model_name or "unknown")

        usage_payload = getattr(completion_response, "usage", None)
        if usage_payload is None and response_is_mapping:
            usage_payload = completion_response.get("usage")
        usage = build_usage_snapshot(_coerce_mapping(usage_payload))

        latency_ms = _compute_latency_ms(start_time, end_time)

        # Go through ``str`` so a float cost is converted by its printed value
        # rather than its binary expansion. A missing or empty cost is unknown.
        response_cost = kwargs.get("response_cost")
        cost_usd = Decimal(str(response_cost)) if response_cost not in (None, "") else None

        event = UsageEvent(
            source="litellm",
            model=ModelString.parse(model_name).canonical(),
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            total_tokens=usage.resolved_total_tokens,
            cost_usd=cost_usd,
            latency_ms=latency_ms,
            count_source="provider_usage_metadata",
            # NOTE(review): the full request kwargs are copied into the event;
            # confirm they never carry data that should not be retained.
            raw={
                "kwargs": dict(kwargs),
                "usage": usage.raw_usage,
            },
        )
        recorder.record(event, budget=budget)

    return _callback
def estimate_and_record_langchain_usage(
    recorder: UsageRecorder,
    *,
    model: str | ModelString,
    usage_metadata: Mapping[str, Any] | None,
    budget: BudgetPolicy | None = None,
    settings: Any = None,
    profile: Mapping[str, Any] | None = None,
    count_source: CountSource = "framework_callback",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent:
    """Estimate cost from LangChain usage metadata and record the result.

    Args:
        recorder: Recorder instance to update.
        model: Raw or typed model string.
        usage_metadata: LangChain usage metadata.
        budget: Optional budget policy.
        settings: Optional app settings used for LiteLLM enrichment.
        profile: Optional LangChain model profile.
        count_source: Provenance label stored on the event.
        run_name: Optional run name stored on the event.
        tags: Optional tags stored on the event.
        metadata: Optional metadata stored on the event.
        cost_labels: Optional cost-grouping labels stored on the event.

    Returns:
        Recorded usage event.
    """
    # Price the call: model info first, then the usage snapshot, then cost.
    estimated_cost = calculate_cost(
        get_model_info(model=model, settings=settings, profile=profile),
        build_usage_snapshot(usage_metadata),
    )

    usage_event = build_langchain_usage_event(
        model=model,
        usage_metadata=usage_metadata,
        cost_usd=estimated_cost,
        count_source=count_source,
        run_name=run_name,
        tags=tags,
        metadata=metadata,
        cost_labels=cost_labels,
    )
    return recorder.record(usage_event, budget=budget)
[docs] def extract_usage_metadata(value: Any) -> dict[str, Any] | None: """Extract best-effort usage metadata from LangChain-style responses.""" if value is None: return None if isinstance(value, Mapping): for key in ("usage_metadata", "usage", "token_usage"): nested = value.get(key) if nested: return _coerce_mapping(nested) response_metadata = value.get("response_metadata") if response_metadata: extracted = extract_usage_metadata(response_metadata) if extracted: return extracted llm_output = value.get("llm_output") if llm_output: extracted = extract_usage_metadata(llm_output) if extracted: return extracted for attr in ("usage_metadata", "usage"): nested = getattr(value, attr, None) if nested: return _coerce_mapping(nested) response_metadata = getattr(value, "response_metadata", None) if response_metadata: extracted = extract_usage_metadata(response_metadata) if extracted: return extracted llm_output = getattr(value, "llm_output", None) if llm_output: extracted = extract_usage_metadata(llm_output) if extracted: return extracted generations = getattr(value, "generations", None) if generations: for generation_group in generations: items = generation_group if isinstance(generation_group, list | tuple) else [generation_group] for generation in items: message = getattr(generation, "message", None) extracted = extract_usage_metadata(message or generation) if extracted: return extracted return None
[docs] def extract_response_model_name(value: Any) -> str | None: """Extract a best-effort model name from a LangChain-style response.""" if value is None: return None if isinstance(value, Mapping): for key in ("model", "model_name"): if value.get(key): return str(value[key]) llm_output = value.get("llm_output") if isinstance(llm_output, Mapping): for key in ("model", "model_name"): if llm_output.get(key): return str(llm_output[key]) for attr in ("model", "model_name"): raw = getattr(value, attr, None) if raw: return str(raw) llm_output = getattr(value, "llm_output", None) if isinstance(llm_output, Mapping): for key in ("model", "model_name"): if llm_output.get(key): return str(llm_output[key]) return None
def record_langchain_response_usage(
    recorder: UsageRecorder,
    *,
    response: Any,
    model: str | ModelString,
    budget: BudgetPolicy | None = None,
    settings: Any = None,
    profile: Mapping[str, Any] | None = None,
    count_source: CountSource = "provider_usage_metadata",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Record usage from a LangChain response when usage metadata is present.

    Args:
        recorder: Recorder instance to update.
        response: LangChain-style response to inspect.
        model: Fallback model, used when the response does not name one.
        budget: Optional budget policy.
        settings: Optional app settings forwarded to cost estimation.
        profile: Optional model profile forwarded to cost estimation.
        count_source: Provenance label stored on the event.
        run_name: Optional run name stored on the event.
        tags: Optional tags stored on the event.
        metadata: Optional metadata stored on the event.
        cost_labels: Optional cost-grouping labels stored on the event.

    Returns:
        The recorded event, or ``None`` when the response carries no usage
        metadata (in which case nothing is recorded).
    """
    usage_metadata = extract_usage_metadata(response)
    if usage_metadata:
        # A model name reported by the response takes precedence over the
        # caller-supplied one.
        detected_model = extract_response_model_name(response)
        return estimate_and_record_langchain_usage(
            recorder,
            model=detected_model or str(model),
            usage_metadata=usage_metadata,
            budget=budget,
            settings=settings,
            profile=profile,
            count_source=count_source,
            run_name=run_name,
            tags=tags,
            metadata=metadata,
            cost_labels=cost_labels,
        )
    return None
class LangChainUsageCallbackHandler:
    """LangChain callback handler that records observed LLM usage metadata.

    The handler stores a recorder plus the context to attach to every event
    and forwards each ``on_llm_end`` response to
    ``record_langchain_response_usage``.

    NOTE(review): this class has no base class and defines only
    ``on_llm_end``; confirm the LangChain callback manager in use accepts such
    a duck-typed handler.
    """

    def __init__(
        self,
        recorder: UsageRecorder,
        *,
        model: str | ModelString,
        budget: BudgetPolicy | None = None,
        settings: Any = None,
        profile: Mapping[str, Any] | None = None,
        run_name: str | None = None,
        tags: list[str] | tuple[str, ...] | None = None,
        metadata: Mapping[str, Any] | None = None,
        cost_labels: Mapping[str, str] | None = None,
    ) -> None:
        """Store the recorder and the per-event context.

        Args:
            recorder: Recorder instance to update.
            model: Fallback model for responses that do not name one.
            budget: Optional budget policy applied to each event.
            settings: Optional app settings forwarded to cost estimation.
            profile: Optional model profile forwarded to cost estimation.
            run_name: Optional run name attached to each event.
            tags: Optional tags; copied into a new list.
            metadata: Optional metadata; copied into a new dict.
            cost_labels: Optional cost-grouping labels; copied into a new dict.
        """
        # Where events go and how they are priced.
        self.recorder = recorder
        self.model = model
        self.budget = budget
        self.settings = settings
        self.profile = profile

        # Context stamped onto every event. Containers are copied so later
        # changes to the caller's objects do not leak into the handler.
        self.run_name = run_name
        self.tags = list(tags) if tags else []
        self.metadata = dict(metadata) if metadata else {}
        self.cost_labels = dict(cost_labels) if cost_labels else {}

    def on_llm_end(self, response: Any, **_: Any) -> None:
        """Record usage when LangChain reports an LLM result.

        Extra keyword arguments are accepted and ignored. Events are recorded
        with ``count_source="framework_callback"``.
        """
        record_langchain_response_usage(
            self.recorder,
            response=response,
            model=self.model,
            budget=self.budget,
            settings=self.settings,
            profile=self.profile,
            count_source="framework_callback",
            run_name=self.run_name,
            tags=self.tags,
            metadata=self.metadata,
            cost_labels=self.cost_labels,
        )
def _coerce_mapping(value: Any) -> dict[str, Any]: """Coerce a value into a plain dictionary when possible.""" if value is None: return {} if isinstance(value, Mapping): return dict(value) if hasattr(value, "model_dump"): return dict(value.model_dump()) if hasattr(value, "dict"): return dict(value.dict()) if hasattr(value, "__dict__"): return {key: val for key, val in vars(value).items() if not key.startswith("_")} return {} def _compute_latency_ms(start_time: datetime | float | int, end_time: datetime | float | int) -> Decimal | None: """Return latency in milliseconds when both timestamps are usable.""" try: if isinstance(start_time, datetime): start_value = start_time.timestamp() else: start_value = float(start_time) if isinstance(end_time, datetime): end_value = end_time.timestamp() else: end_value = float(end_time) except Exception: return None return Decimal(str((end_value - start_value) * 1000))