"""Usage and cost callback helpers.
Purpose:
Provide ergonomic callback helpers that work with LangChain usage metadata
and the native LiteLLM callback interface.
Design:
- Normalize usage and cost into a shared ``UsageEvent`` model.
- Offer a recorder object that can accumulate usage across many calls.
- Expose a LiteLLM-compatible success callback factory for cost and usage
tracking.
- Keep callback helpers transport-agnostic so they can support chat,
embeddings, and future model families.
Examples:
>>> recorder = UsageRecorder()
>>> callback = make_litellm_cost_callback(recorder)
>>> callable(callback)
True
"""
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
from decimal import Decimal
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, computed_field
from .logging import get_logger, log_event
from .metadata import build_usage_snapshot, calculate_cost, get_model_info
from .types import ModelString
# Allowed provenance labels for token counts (stored on ``UsageEvent.count_source``).
CountSource = Literal[
    "provider_preflight",
    "provider_usage_metadata",
    "dspy_usage_metadata",
    "local_tokenizer",
    "framework_callback",
    "approximation",
]
# Module-level logger; ``UsageRecorder.record`` emits ``usage.recorded`` events through it.
logger = get_logger(__name__)
class BudgetExceededError(RuntimeError):
    """Raised when a usage or cost budget is exceeded.

    ``BudgetPolicy.check`` raises this when an event crosses one of its hard
    (``error_*``) thresholds.
    """
class UsageEvent(BaseModel):
    """Normalized usage and cost event.

    Args:
        source: Origin of the event, such as ``langchain`` or ``litellm``.
        model: Typed model string.
        input_tokens: Input token count (defaults to 0).
        output_tokens: Output token count (defaults to 0).
        total_tokens: Total token count (defaults to 0).
        cost_usd: Actual or estimated USD cost, or ``None`` when unknown.
        latency_ms: Measured latency in milliseconds when available.
        count_source: Provenance for token counts.
        run_name: Optional LangChain run name or logical application run.
        tags: Optional framework/application tags.
        metadata: Optional framework/application metadata.
        cost_labels: Optional normalized labels used for cost grouping.
        raw: Original payload fragments.
    """

    # Unknown keys are rejected, so every field passed by producers must be declared here.
    model_config = ConfigDict(extra="forbid")

    source: str
    # NOTE(review): assumes ``ModelString`` is a pydantic-compatible type.
    model: ModelString
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cost_usd: Decimal | None = None
    latency_ms: Decimal | None = None
    count_source: CountSource = "framework_callback"
    run_name: str | None = None
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)
    raw: dict[str, Any] = Field(default_factory=dict)
class UsageSummary(BaseModel):
    """Aggregate view over recorded usage events.

    Attributes:
        event_count: Number of events aggregated.
        input_tokens: Sum of input tokens across events.
        output_tokens: Sum of output tokens across events.
        total_tokens: Sum of total tokens across events.
        total_cost_usd: Sum of known event costs; events without a cost add nothing.
        by_model: Total tokens keyed by model name.
        by_provider: Total tokens keyed by provider.
        by_run: Total tokens keyed by run name.
    """

    model_config = ConfigDict(extra="forbid")

    event_count: int = 0
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    total_cost_usd: Decimal = Decimal("0")
    by_model: dict[str, int] = Field(default_factory=dict)
    by_provider: dict[str, int] = Field(default_factory=dict)
    by_run: dict[str, int] = Field(default_factory=dict)
class BudgetPolicy(BaseModel):
    """Simple budget and warning thresholds for usage tracking.

    Every threshold applies to a single event and is exclusive: a value equal
    to the threshold does not trigger it. A threshold set to ``None`` is
    disabled, and cost thresholds are skipped for events whose ``cost_usd``
    is ``None``.

    Args:
        warn_cost_usd: Optional single-event cost warning threshold.
        error_cost_usd: Optional single-event hard cost threshold.
        warn_total_tokens: Optional single-event total-token warning threshold.
        error_total_tokens: Optional single-event hard token threshold.
    """

    model_config = ConfigDict(extra="forbid")

    warn_cost_usd: Decimal | None = None
    error_cost_usd: Decimal | None = None
    warn_total_tokens: int | None = None
    error_total_tokens: int | None = None

    def check(self, event: UsageEvent) -> list[str]:
        """Return warning messages or raise when hard thresholds are crossed.

        Args:
            event: Usage event to evaluate.

        Returns:
            Warning messages (possibly empty), token warning first.

        Raises:
            BudgetExceededError: If a hard threshold is exceeded. The token limit
                is checked before the cost limit.
        """
        # Hard limits are evaluated first, so an over-budget event raises
        # without producing any warnings.
        if self.error_total_tokens is not None and event.total_tokens > self.error_total_tokens:
            raise BudgetExceededError(
                f"Total token budget exceeded: {event.total_tokens} > {self.error_total_tokens}."
            )
        if (
            self.error_cost_usd is not None
            and event.cost_usd is not None
            and event.cost_usd > self.error_cost_usd
        ):
            raise BudgetExceededError(
                f"Cost budget exceeded: {event.cost_usd} > {self.error_cost_usd}."
            )

        warnings: list[str] = []
        if self.warn_total_tokens is not None and event.total_tokens > self.warn_total_tokens:
            warnings.append(
                f"Total token warning threshold exceeded: {event.total_tokens} > {self.warn_total_tokens}."
            )
        if (
            self.warn_cost_usd is not None
            and event.cost_usd is not None
            and event.cost_usd > self.warn_cost_usd
        ):
            warnings.append(
                f"Cost warning threshold exceeded: {event.cost_usd} > {self.warn_cost_usd}."
            )
        return warnings
class UsageRecorder(BaseModel):
    """In-memory recorder for normalized usage events.

    Events are kept in ``events`` in the order they were recorded. The token
    and cost totals are recomputed from that list on every access.
    """

    model_config = ConfigDict(extra="forbid")

    events: list[UsageEvent] = Field(default_factory=list)
    # Soft-threshold messages returned by ``BudgetPolicy.check``.
    warnings: list[str] = Field(default_factory=list)

    def record(self, event: UsageEvent, *, budget: BudgetPolicy | None = None) -> UsageEvent:
        """Record an event and optionally apply a budget policy.

        Args:
            event: Normalized usage event to store.
            budget: Optional policy evaluated before the event is stored.

        Returns:
            The recorded ``event``, unchanged.

        Raises:
            BudgetExceededError: Propagated from ``budget.check`` when a hard
                threshold is crossed. The budget is checked before appending,
                so the event is not recorded in that case.
        """
        if budget is not None:
            self.warnings.extend(budget.check(event))
        self.events.append(event)
        log_event(
            logger,
            "usage.recorded",
            source=event.source,
            model=event.model.as_langchain(),
            total_tokens=event.total_tokens,
            count_source=event.count_source,
            cost_usd=str(event.cost_usd) if event.cost_usd is not None else None,
        )
        return event

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_tokens(self) -> int:
        """Return total tokens recorded so far."""
        return sum(event.total_tokens for event in self.events)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def total_cost_usd(self) -> Decimal:
        """Return total cost recorded so far, skipping events with unknown cost."""
        total = Decimal("0")
        for event in self.events:
            if event.cost_usd is not None:
                total += event.cost_usd
        return total

    @computed_field  # type: ignore[prop-decorator]
    @property
    def input_tokens(self) -> int:
        """Return input tokens recorded so far."""
        return sum(event.input_tokens for event in self.events)

    @computed_field  # type: ignore[prop-decorator]
    @property
    def output_tokens(self) -> int:
        """Return output tokens recorded so far."""
        return sum(event.output_tokens for event in self.events)

    def summary(self) -> UsageSummary:
        """Return an aggregate usage summary grouped by model, provider, and run.

        Each grouping maps a key to the sum of ``total_tokens`` for the matching
        events. Events without a provider or run name are left out of that
        grouping but still count toward the overall totals.
        """
        by_model: dict[str, int] = {}
        by_provider: dict[str, int] = {}
        by_run: dict[str, int] = {}
        for event in self.events:
            tokens = event.total_tokens
            model_name = event.model.as_langchain()
            by_model[model_name] = by_model.get(model_name, 0) + tokens
            if event.model.provider is not None:
                provider = event.model.provider.value
                by_provider[provider] = by_provider.get(provider, 0) + tokens
            if event.run_name:
                by_run[event.run_name] = by_run.get(event.run_name, 0) + tokens
        return UsageSummary(
            event_count=len(self.events),
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            total_tokens=self.total_tokens,
            total_cost_usd=self.total_cost_usd,
            by_model=by_model,
            by_provider=by_provider,
            by_run=by_run,
        )
def build_langchain_usage_event(
    *,
    model: str | ModelString,
    usage_metadata: Mapping[str, Any] | None,
    cost_usd: Decimal | None = None,
    latency_ms: Decimal | None = None,
    count_source: CountSource = "framework_callback",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent:
    """Build a normalized event from LangChain usage metadata.

    Args:
        model: Raw or typed model string; parsed and canonicalized.
        usage_metadata: LangChain usage metadata, or ``None``; normalized via
            ``build_usage_snapshot``.
        cost_usd: Precomputed USD cost, if known.
        latency_ms: Measured latency in milliseconds, if known.
        count_source: Provenance label for the token counts.
        run_name: Optional logical run name.
        tags: Optional tags, copied into a new list.
        metadata: Optional metadata, copied into a new dict.
        cost_labels: Optional cost-grouping labels, copied into a new dict.

    Returns:
        A ``UsageEvent`` with ``source="langchain"``. Its ``total_tokens`` is the
        snapshot's ``resolved_total_tokens``.
    """
    parsed_model = ModelString.parse(model).canonical()
    usage = build_usage_snapshot(usage_metadata)
    return UsageEvent(
        source="langchain",
        model=parsed_model,
        input_tokens=usage.input_tokens,
        output_tokens=usage.output_tokens,
        total_tokens=usage.resolved_total_tokens,
        cost_usd=cost_usd,
        latency_ms=latency_ms,
        count_source=count_source,
        run_name=run_name,
        tags=list(tags or []),
        metadata=dict(metadata or {}),
        cost_labels=dict(cost_labels or {}),
        raw=usage.raw_usage,
    )
def make_litellm_cost_callback(
    recorder: UsageRecorder,
    *,
    budget: BudgetPolicy | None = None,
) -> Any:
    """Return a LiteLLM success callback that records cost and usage.

    The returned function takes ``(kwargs, completion_response, start_time,
    end_time)``. Each call builds a ``UsageEvent`` with ``source="litellm"``
    and records it on ``recorder``.

    Args:
        recorder: Recorder instance to update.
        budget: Optional budget policy to evaluate for each event.

    Returns:
        LiteLLM-compatible callback function.
    """

    def _response_field(response: Any, name: str) -> Any:
        """Read ``name`` from an attribute-style or mapping-style response."""
        value = getattr(response, name, None)
        if value is None and isinstance(response, Mapping):
            value = response.get(name)
        return value

    def _callback(
        kwargs: Mapping[str, Any],
        completion_response: Any,
        start_time: datetime | float | int,
        end_time: datetime | float | int,
    ) -> None:
        # Prefer the requested model name. Otherwise use what the response reports
        # (as an attribute or a mapping key), then fall back to a placeholder.
        model_name = str(
            kwargs.get("model") or _response_field(completion_response, "model") or "unknown"
        )
        usage = build_usage_snapshot(_coerce_mapping(_response_field(completion_response, "usage")))
        latency_ms = _compute_latency_ms(start_time, end_time)
        # A missing or empty ``response_cost`` means the cost is unknown.
        response_cost = kwargs.get("response_cost")
        cost_usd = Decimal(str(response_cost)) if response_cost not in (None, "") else None
        event = UsageEvent(
            source="litellm",
            model=ModelString.parse(model_name).canonical(),
            input_tokens=usage.input_tokens,
            output_tokens=usage.output_tokens,
            total_tokens=usage.resolved_total_tokens,
            cost_usd=cost_usd,
            latency_ms=latency_ms,
            count_source="provider_usage_metadata",
            # NOTE(review): ``kwargs`` is stored verbatim and may contain prompts or
            # credentials passed to LiteLLM. Confirm this is acceptable before
            # persisting ``raw``.
            raw={
                "kwargs": dict(kwargs),
                "usage": usage.raw_usage,
            },
        )
        recorder.record(event, budget=budget)

    return _callback
def estimate_and_record_langchain_usage(
    recorder: UsageRecorder,
    *,
    model: str | ModelString,
    usage_metadata: Mapping[str, Any] | None,
    budget: BudgetPolicy | None = None,
    settings: Any = None,
    profile: Mapping[str, Any] | None = None,
    count_source: CountSource = "framework_callback",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
    latency_ms: Decimal | None = None,
) -> UsageEvent:
    """Estimate cost from LangChain usage metadata and record the result.

    Args:
        recorder: Recorder instance to update.
        model: Raw or typed model string.
        usage_metadata: LangChain usage metadata.
        budget: Optional budget policy.
        settings: Optional app settings used for LiteLLM enrichment.
        profile: Optional LangChain model profile.
        count_source: Provenance label for the token counts.
        run_name: Optional logical run name.
        tags: Optional tags attached to the event.
        metadata: Optional metadata attached to the event.
        cost_labels: Optional cost-grouping labels attached to the event.
        latency_ms: Optional measured latency in milliseconds.

    Returns:
        Recorded usage event.

    Raises:
        BudgetExceededError: If ``budget`` has a hard threshold the event crosses.
    """
    meta = get_model_info(model=model, settings=settings, profile=profile)
    usage = build_usage_snapshot(usage_metadata)
    cost = calculate_cost(meta, usage)
    event = build_langchain_usage_event(
        model=model,
        usage_metadata=usage_metadata,
        cost_usd=cost,
        latency_ms=latency_ms,
        count_source=count_source,
        run_name=run_name,
        tags=tags,
        metadata=metadata,
        cost_labels=cost_labels,
    )
    return recorder.record(event, budget=budget)
def record_langchain_response_usage(
    recorder: UsageRecorder,
    *,
    response: Any,
    model: str | ModelString,
    budget: BudgetPolicy | None = None,
    settings: Any = None,
    profile: Mapping[str, Any] | None = None,
    count_source: CountSource = "provider_usage_metadata",
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Record usage from a LangChain response when usage metadata is present.

    Args:
        recorder: Recorder instance to update.
        response: LangChain response object to inspect.
        model: Fallback model used when the response does not report one.
        budget: Optional budget policy.
        settings: Optional app settings used for LiteLLM enrichment.
        profile: Optional LangChain model profile.
        count_source: Provenance label for the token counts.
        run_name: Optional logical run name.
        tags: Optional tags attached to the event.
        metadata: Optional metadata attached to the event.
        cost_labels: Optional cost-grouping labels attached to the event.

    Returns:
        The recorded event, or ``None`` when the response carries no usage metadata.
    """
    # NOTE(review): confirm ``extract_usage_metadata`` and
    # ``extract_response_model_name`` are defined or imported in this module.
    usage_metadata = extract_usage_metadata(response)
    if not usage_metadata:
        return None
    # Pass the fallback through unchanged: downstream accepts ``str | ModelString``,
    # so there is no need to depend on ``ModelString.__str__`` round-tripping.
    response_model = extract_response_model_name(response) or model
    return estimate_and_record_langchain_usage(
        recorder,
        model=response_model,
        usage_metadata=usage_metadata,
        budget=budget,
        settings=settings,
        profile=profile,
        count_source=count_source,
        run_name=run_name,
        tags=tags,
        metadata=metadata,
        cost_labels=cost_labels,
    )
class LangChainUsageCallbackHandler:
    """LangChain callback handler that records observed LLM usage metadata.

    NOTE(review): this class does not inherit from LangChain's
    ``BaseCallbackHandler``. Confirm that the callback manager in use accepts
    duck-typed handlers, since it may look up attributes such as ``ignore_llm``
    or ``raise_error``.
    """

    def __init__(
        self,
        recorder: UsageRecorder,
        *,
        model: str | ModelString,
        budget: BudgetPolicy | None = None,
        settings: Any = None,
        profile: Mapping[str, Any] | None = None,
        run_name: str | None = None,
        tags: list[str] | tuple[str, ...] | None = None,
        metadata: Mapping[str, Any] | None = None,
        cost_labels: Mapping[str, str] | None = None,
    ) -> None:
        """Store the recorder and the context forwarded on every ``on_llm_end``.

        Args:
            recorder: Recorder instance to update.
            model: Fallback model used when a response does not report one.
            budget: Optional budget policy applied to each event.
            settings: Optional app settings used for LiteLLM enrichment.
            profile: Optional LangChain model profile.
            run_name: Optional logical run name.
            tags: Optional tags, copied into a list.
            metadata: Optional metadata, copied into a dict.
            cost_labels: Optional cost-grouping labels, copied into a dict.
        """
        self.recorder = recorder
        # Every attribute read by ``on_llm_end`` must be assigned here.
        self.model = model
        self.budget = budget
        self.settings = settings
        self.profile = profile
        self.run_name = run_name
        self.tags = list(tags or [])
        self.metadata = dict(metadata or {})
        self.cost_labels = dict(cost_labels or {})

    def on_llm_end(self, response: Any, **_: Any) -> None:
        """Record usage when LangChain reports an LLM result."""
        record_langchain_response_usage(
            self.recorder,
            response=response,
            model=self.model,
            budget=self.budget,
            settings=self.settings,
            profile=self.profile,
            count_source="framework_callback",
            run_name=self.run_name,
            tags=self.tags,
            metadata=self.metadata,
            cost_labels=self.cost_labels,
        )
def _coerce_mapping(value: Any) -> dict[str, Any]:
    """Best-effort conversion of ``value`` into a plain ``dict``.

    Strategies are tried in order:

    1. ``None`` becomes an empty dict.
    2. Any ``Mapping`` is shallow-copied.
    3. Objects exposing ``model_dump()`` or, failing that, ``dict()`` are
       converted through that method (e.g. pydantic-style models).
    4. Other objects with a ``__dict__`` contribute the attributes whose names
       do not start with an underscore.

    Anything else yields an empty dict.
    """
    if value is None:
        return {}
    if isinstance(value, Mapping):
        return dict(value)
    for dump_method in ("model_dump", "dict"):
        if hasattr(value, dump_method):
            return dict(getattr(value, dump_method)())
    if not hasattr(value, "__dict__"):
        return {}
    return {name: attr for name, attr in vars(value).items() if not name.startswith("_")}
def _compute_latency_ms(
    start_time: datetime | float | int,
    end_time: datetime | float | int,
) -> Decimal | None:
    """Return latency in milliseconds when both timestamps are usable.

    Each timestamp may be a ``datetime`` (converted with ``datetime.timestamp()``)
    or any value accepted by ``float()``, treated as seconds. The result is
    ``(end - start) * 1000`` as a ``Decimal``.

    Returns ``None`` when either timestamp cannot be converted, or when the
    difference is not finite (NaN or infinity), which is not a usable latency.

    Note:
        ``datetime.timestamp()`` interprets naive datetimes as local time, so
        mixing naive and timezone-aware values can skew the result.
    """
    try:
        start_seconds = start_time.timestamp() if isinstance(start_time, datetime) else float(start_time)
        end_seconds = end_time.timestamp() if isinstance(end_time, datetime) else float(end_time)
    except (TypeError, ValueError, OverflowError, OSError):
        # Best-effort: an unusable timestamp should drop the latency, not the event.
        return None
    latency = Decimal(str((end_seconds - start_seconds) * 1000))
    return latency if latency.is_finite() else None