# Source code for ooai_llm.dspy

"""Optional DSPy model integration.

Purpose:
    Create DSPy ``LM`` instances from the same model strings, serializable
    profiles, and runtime identity used by the LangChain-first OOAI surface.

Design:
    - Keep DSPy optional and lazily imported.
    - Resolve all model choices to LiteLLM-style ``provider/model`` strings,
      which is DSPy's constructor shape.
    - Expose only model-substrate helpers here. DSPy programs, runnables,
      LangGraph nodes, optimizers, and artifacts belong in ``ooai-agents``.

Examples:
    >>> config = DSPyLMConfig(model="openai:gpt-5-mini", temperature=0)
    >>> config.resolve_model_name()
    'openai/gpt-5-mini'
"""

from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field

from .callbacks import BudgetPolicy, UsageEvent, UsageRecorder
from .factory import native_environment_overrides, resolve_factory_settings, resolve_model_string
from .metadata import ModelInfo, build_usage_snapshot, calculate_cost, get_model_info
from .profiles import CacheKeyPolicy, ChatModelProfile, LLM
from .providers import Provider
from .reasoning import ReasoningInput, build_reasoning_resolution
from .settings import AppSettings, ModelAliasName, ModelPresetName
from .types import ModelString

if TYPE_CHECKING:
    from types import ModuleType

DSPyModelType = Literal["chat", "text", "responses"]
"""Allowed values for the ``model_type`` argument handed to ``dspy.LM``."""
class DSPyDependencyError(ImportError):
    """Signal that a DSPy helper was called but the optional ``dspy`` package is missing.

    Subclasses ``ImportError`` so existing ``except ImportError`` handlers
    keep working.
    """
class DSPyLMConfig(BaseModel):
    """Serializable description of how to build a DSPy ``LM``.

    Unknown fields are rejected (``extra="forbid"``).

    Args:
        model: Explicit OOAI model string, such as ``"openai:gpt-5-mini"``.
        alias: Optional configured model alias when ``model`` is omitted.
        provider: Optional provider for alias/preset resolution.
        preset: Provider preset name when resolving by provider.
        model_type: DSPy LM type: ``"chat"``, ``"text"``, or ``"responses"``.
        lm_kwargs: DSPy/LiteLLM passthrough kwargs for provider-specific needs.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # --- model selection ---
    model: str | None = None
    alias: ModelAliasName | None = None
    provider: str | Provider | None = None
    preset: ModelPresetName = "default"
    model_type: DSPyModelType = "chat"

    # --- generation / transport options forwarded to ``dspy.LM`` when set ---
    temperature: float | None = None
    max_tokens: int | None = Field(default=None, ge=1)
    top_p: float | None = Field(default=None, ge=0, le=1)
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    cache: bool | None = None
    callbacks: list[Any] = Field(default_factory=list)
    num_retries: int | None = Field(default=None, ge=0)
    timeout: float | None = Field(default=None, ge=0)
    parallel_tool_calls: bool | None = None
    reasoning: ReasoningInput = None
    finetuning_model: str | None = None
    launch_kwargs: dict[str, Any] | None = None
    train_kwargs: dict[str, Any] | None = None
    use_developer_role: bool | None = None
    lm_kwargs: dict[str, Any] = Field(default_factory=dict)

    # --- model catalogue refresh behaviour used during resolution ---
    auto_refresh_models: bool | None = None
    force_model_refresh: bool = False

    # --- identity and bookkeeping surfaced through ``trace_metadata`` ---
    runtime_id: str | None = None
    runtime_uuid: UUID | str | None = None
    profile_id: str | None = None
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)

    @classmethod
    def from_profile(
        cls,
        profile: ChatModelProfile,
        *,
        model_type: DSPyModelType = "chat",
    ) -> DSPyLMConfig:
        """Translate a chat-model profile into a DSPy config.

        Args:
            profile: Source profile whose fields are copied over.
            model_type: DSPy protocol style for the resulting config.

        Returns:
            A new config. ``profile.max_retries`` becomes ``num_retries``,
            the profile's cache setting is reduced to a plain flag, and the
            profile's constructor/model kwargs are flattened into
            ``lm_kwargs``. List and dict fields are copied, not shared.
        """
        fields: dict[str, Any] = {
            "model": profile.model,
            "alias": profile.alias,
            "provider": profile.provider,
            "preset": profile.preset,
            "model_type": model_type,
            "temperature": profile.temperature,
            "max_tokens": profile.max_tokens,
            "top_p": profile.top_p,
            "frequency_penalty": profile.frequency_penalty,
            "presence_penalty": profile.presence_penalty,
            "seed": profile.seed,
            "stop": profile.stop,
            "cache": _profile_cache_flag(profile.cache),
            "num_retries": profile.max_retries,
            "timeout": profile.timeout,
            "parallel_tool_calls": profile.parallel_tool_calls,
            "reasoning": profile.reasoning,
            "lm_kwargs": _profile_lm_kwargs(profile),
            "auto_refresh_models": profile.auto_refresh_models,
            "force_model_refresh": profile.force_model_refresh,
            "profile_id": profile.id,
            "tags": list(profile.tags),
            "metadata": dict(profile.metadata),
            "cost_labels": dict(profile.cost_labels),
        }
        return cls(**fields)

    @classmethod
    def from_runtime(
        cls,
        runtime: LLM,
        *,
        model_type: DSPyModelType = "chat",
    ) -> DSPyLMConfig:
        """Translate an ``LLM`` runtime into a DSPy config.

        The runtime's profile is converted with :meth:`from_profile`, then the
        runtime's ``id`` and ``uuid`` are stamped onto a copy of the result.
        """
        base = cls.from_profile(runtime.profile, model_type=model_type)
        identity = {"runtime_id": runtime.id, "runtime_uuid": runtime.uuid}
        return base.model_copy(update=identity)

    def resolve_model_string(self, *, settings: AppSettings | None = None) -> ModelString:
        """Resolve this config to an OOAI model string.

        Args:
            settings: Optional settings; they are first passed through
                ``resolve_factory_settings`` together with this config's
                refresh flags.
        """
        effective_settings = resolve_factory_settings(
            settings,
            auto_refresh_models=self.auto_refresh_models,
            force_model_refresh=self.force_model_refresh,
        )
        # Inside a method body this name is the module-level helper, not the method.
        return resolve_model_string(
            settings=effective_settings,
            model=self.model,
            alias=self.alias,
            provider=self.provider,
            preset=self.preset,
        )

    def resolve_model_name(self, *, settings: AppSettings | None = None) -> str:
        """Resolve this config to the LiteLLM-style name that DSPy expects."""
        resolved = self.resolve_model_string(settings=settings)
        return resolved.as_litellm()

    def trace_metadata(
        self,
        *,
        model: ModelString | None = None,
        settings: AppSettings | None = None,
    ) -> dict[str, Any]:
        """Return runtime/model metadata for downstream wrappers.

        Args:
            model: Already-resolved model; when falsy the model is resolved
                from this config.
            settings: Settings used only when the model must be resolved.

        Returns:
            A copy of ``metadata`` extended with ``ooai_*`` keys. Identity
            keys are added only when the corresponding field is set; the two
            model keys are always written.
        """
        resolved_model = model or self.resolve_model_string(settings=settings)
        payload = dict(self.metadata)

        optional_entries = (
            ("ooai_runtime_id", self.runtime_id),
            # UUIDs are stringified so the payload stays plain data.
            ("ooai_runtime_uuid", None if self.runtime_uuid is None else str(self.runtime_uuid)),
            ("ooai_profile_id", self.profile_id),
        )
        for key, value in optional_entries:
            if value is not None:
                payload[key] = value

        if self.cost_labels:
            payload["ooai_cost_labels"] = dict(self.cost_labels)

        payload["ooai_model"] = resolved_model.as_langchain()
        payload["ooai_litellm_model"] = resolved_model.as_litellm()
        return payload

    def to_lm_kwargs(self, *, settings: AppSettings | None = None) -> dict[str, Any]:
        """Return keyword arguments for ``dspy.LM``.

        Precedence, lowest to highest: kwargs derived from ``reasoning``,
        then ``lm_kwargs``, then ``model_type`` and every explicit field that
        is not ``None``. ``callbacks`` is included only when non-empty.
        """
        resolved_model = self.resolve_model_string(settings=settings)
        resolution = build_reasoning_resolution(
            model=resolved_model,
            provider=self.provider,
            reasoning=self.reasoning,
        )

        lm_options: dict[str, Any] = {}
        if resolution is not None:
            lm_options.update(resolution.constructor_kwargs)
        lm_options.update(self.lm_kwargs)
        lm_options["model_type"] = self.model_type

        # Field name and ``dspy.LM`` keyword are identical for all of these.
        optional_fields = (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "seed",
            "stop",
            "cache",
            "num_retries",
            "timeout",
            "parallel_tool_calls",
            "finetuning_model",
            "launch_kwargs",
            "train_kwargs",
            "use_developer_role",
        )
        for field_name in optional_fields:
            _set_if_not_none(lm_options, field_name, getattr(self, field_name))

        if self.callbacks:
            lm_options["callbacks"] = list(self.callbacks)
        return lm_options
@dataclass(slots=True, frozen=True)
[docs] class CreatedDSPyLMBundle: """DSPy LM plus resolved OOAI metadata."""
[docs] model: ModelString
[docs] lm: Any
[docs] metadata: ModelInfo
[docs] config: DSPyLMConfig
[docs] trace_metadata: dict[str, Any]
def resolve_dspy_model_name(
    model: str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    alias: ModelAliasName | None = None,
    provider: Provider | str | None = None,
    preset: ModelPresetName = "default",
    auto_refresh_models: bool | None = None,
    force_model_refresh: bool = False,
) -> str:
    """Turn a model choice into the LiteLLM-style name DSPy expects.

    Args:
        model: Explicit model; converted with ``str()`` when given.
        settings: Optional settings used during resolution.
        alias: Configured alias to resolve when ``model`` is omitted.
        provider: Provider used for alias/preset resolution.
        preset: Provider preset name.
        auto_refresh_models: Forwarded to the temporary config.
        force_model_refresh: Forwarded to the temporary config.

    Returns:
        The result of ``DSPyLMConfig.resolve_model_name`` for a throwaway
        config built from the arguments.
    """
    explicit_model = None if model is None else str(model)
    throwaway = DSPyLMConfig(
        model=explicit_model,
        alias=alias,
        provider=provider,
        preset=preset,
        auto_refresh_models=auto_refresh_models,
        force_model_refresh=force_model_refresh,
    )
    return throwaway.resolve_model_name(settings=settings)
def create_dspy_lm(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    **kwargs: Any,
) -> Any:
    """Build a native DSPy ``LM`` from an OOAI config.

    This is a thin wrapper over ``create_dspy_lm_bundle`` that discards
    everything except the LM. Keyword arguments that are not
    ``DSPyLMConfig`` fields are forwarded to ``dspy.LM`` through
    ``DSPyLMConfig.lm_kwargs``.
    """
    bundle = create_dspy_lm_bundle(config, settings=settings, **kwargs)
    return bundle.lm
def create_dspy_lm_bundle(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    **kwargs: Any,
) -> CreatedDSPyLMBundle:
    """Build a DSPy ``LM`` and package it with its resolved OOAI metadata.

    Args:
        config: Profile, DSPy config, runtime, model string, or ``None``.
        settings: Optional settings; coercion may substitute its own.
        **kwargs: Config field overrides and ``dspy.LM`` passthrough kwargs.

    Returns:
        A bundle holding the LM, resolved model, model info, the effective
        config, and trace metadata.

    Raises:
        DSPyDependencyError: If the ``dspy`` package is not installed.
    """
    # Import first so a missing dependency is reported before any resolution work.
    dspy_module = _import_dspy()
    effective_config, effective_settings = _coerce_config(config, settings=settings, kwargs=kwargs)

    resolved_model = effective_config.resolve_model_string(settings=effective_settings)
    constructor_kwargs = effective_config.to_lm_kwargs(settings=effective_settings)

    # The LM is constructed while the settings-derived environment overrides are active.
    with native_environment_overrides(effective_settings):
        language_model = dspy_module.LM(resolved_model.as_litellm(), **constructor_kwargs)

    model_info = get_model_info(
        model=resolved_model,
        settings=effective_settings,
        provider=effective_config.provider,
    )
    trace = effective_config.trace_metadata(model=resolved_model, settings=effective_settings)
    return CreatedDSPyLMBundle(
        model=resolved_model,
        lm=language_model,
        metadata=model_info,
        config=effective_config,
        trace_metadata=trace,
    )
def configure_dspy_lm(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    configure_kwargs: Mapping[str, Any] | None = None,
    **kwargs: Any,
) -> Any:
    """Create a DSPy ``LM``, configure DSPy's global settings, and return it.

    Args:
        config: Profile, DSPy config, runtime, model string, or ``None``.
        settings: Optional settings forwarded to ``create_dspy_lm``.
        configure_kwargs: Extra keyword arguments for DSPy's ``configure``
            call. An ``"lm"`` entry here is ignored: the LM created by this
            function is always the one installed, so the returned object and
            the configured object cannot diverge.
        **kwargs: Forwarded to ``create_dspy_lm``.

    Returns:
        The newly created LM that was installed into DSPy.

    Raises:
        DSPyDependencyError: If the ``dspy`` package is not installed.
        RuntimeError: If neither ``dspy.configure`` nor
            ``dspy.settings.configure`` is callable.
    """
    dspy = _import_dspy()
    lm = create_dspy_lm(config, settings=settings, **kwargs)

    # "lm" goes last so caller-supplied configure_kwargs cannot replace it.
    payload = {**dict(configure_kwargs or {}), "lm": lm}

    # Prefer the top-level ``dspy.configure``; otherwise try ``dspy.settings.configure``.
    for owner in (dspy, getattr(dspy, "settings", None)):
        configure = getattr(owner, "configure", None)
        if callable(configure):
            configure(**payload)
            return lm

    raise RuntimeError("Installed DSPy package does not expose dspy.configure or dspy.settings.configure.")
def extract_dspy_usage(value: Any) -> dict[str, Any] | None:
    """Pull token usage out of a DSPy prediction or LM object and normalize it.

    Returns:
        ``None`` when no raw usage data can be located on ``value``;
        otherwise whatever ``_normalize_dspy_usage`` produces for the raw
        data, which may be an empty dict if the data is not recognized.
    """
    raw_usage = _extract_dspy_usage_raw(value)
    return _normalize_dspy_usage(raw_usage) if raw_usage else None
def build_dspy_usage_event(
    *,
    prediction: Any,
    model: str | ModelString,
    settings: AppSettings | None = None,
    budget: BudgetPolicy | None = None,
    profile: Mapping[str, Any] | None = None,
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Build a usage event from a DSPy prediction when usage is available.

    Args:
        prediction: Object to extract usage from.
        model: Model the usage is attributed to; parsed and canonicalized.
        settings: Optional settings for the model-info lookup.
        budget: Accepted for signature parity with ``record_dspy_usage``;
            not used when building the event.
        profile: Optional profile mapping for the model-info lookup.
        run_name: Optional run label stored on the event.
        tags: Optional tags; copied into a list.
        metadata: Optional metadata; copied into a dict.
        cost_labels: Optional cost labels; copied into a dict.

    Returns:
        A ``UsageEvent`` with ``source="dspy"``, or ``None`` when no usage
        could be extracted from ``prediction``.
    """
    extracted = extract_dspy_usage(prediction)
    if not extracted:
        return None

    canonical_model = ModelString.parse(model).canonical()
    info = get_model_info(model=canonical_model, settings=settings, profile=profile)
    snapshot = build_usage_snapshot(extracted)
    cost_usd = calculate_cost(info, snapshot)

    event_fields: dict[str, Any] = {
        "source": "dspy",
        "model": canonical_model,
        "input_tokens": snapshot.input_tokens,
        "output_tokens": snapshot.output_tokens,
        "total_tokens": snapshot.resolved_total_tokens,
        "cost_usd": cost_usd,
        "count_source": "dspy_usage_metadata",
        "run_name": run_name,
        "tags": list(tags) if tags else [],
        "metadata": dict(metadata) if metadata else {},
        "cost_labels": dict(cost_labels) if cost_labels else {},
        "raw": snapshot.raw_usage,
    }
    return UsageEvent(**event_fields)
def record_dspy_usage(
    recorder: UsageRecorder,
    *,
    prediction: Any,
    model: str | ModelString,
    settings: AppSettings | None = None,
    budget: BudgetPolicy | None = None,
    profile: Mapping[str, Any] | None = None,
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Record DSPy prediction usage when the prediction exposes counts.

    The event is built with ``build_dspy_usage_event`` from the same
    arguments. When no event can be built nothing is recorded and ``None``
    is returned; otherwise the result of ``recorder.record(event,
    budget=budget)`` is returned.
    """
    usage_event = build_dspy_usage_event(
        prediction=prediction,
        model=model,
        settings=settings,
        budget=budget,
        profile=profile,
        run_name=run_name,
        tags=tags,
        metadata=metadata,
        cost_labels=cost_labels,
    )
    if usage_event is not None:
        return recorder.record(usage_event, budget=budget)
    return None
def _import_dspy() -> "ModuleType":
    """Import ``dspy`` on demand and return the module.

    Raises:
        DSPyDependencyError: If the package is not installed; the original
            ``ModuleNotFoundError`` is chained as the cause.
    """
    try:
        import dspy
    except ModuleNotFoundError as exc:
        message = (
            "DSPy support requires the optional dependency. "
            'Install it with `pip install "ooai-llm[dspy]"` or `pdm add "ooai-llm[dspy]"`.'
        )
        raise DSPyDependencyError(message) from exc
    else:
        return dspy


def _coerce_config(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None,
    *,
    settings: AppSettings | None,
    kwargs: Mapping[str, Any],
) -> tuple[DSPyLMConfig, AppSettings]:
    """Normalize any accepted ``config`` input into ``(DSPyLMConfig, settings)``.

    ``kwargs`` are split into config-field overrides and ``lm_kwargs``
    passthrough, then layered on top of the config derived from ``config``.
    """
    overrides = _split_config_updates(kwargs)

    if isinstance(config, LLM):
        # Explicit settings win; otherwise reuse the runtime's own settings.
        # NOTE(review): unlike the other branches there is no AppSettings()
        # fallback here, so this can yield None if the runtime carries no
        # settings - confirm whether LLM.settings is always populated.
        runtime_settings = settings or config.settings
        base = DSPyLMConfig.from_runtime(config, model_type=overrides.pop("model_type", "chat"))
        return _apply_updates(base, overrides), runtime_settings

    if isinstance(config, DSPyLMConfig):
        coerced = _apply_updates(config, overrides)
    elif isinstance(config, ChatModelProfile):
        base = DSPyLMConfig.from_profile(config, model_type=overrides.pop("model_type", "chat"))
        coerced = _apply_updates(base, overrides)
    else:
        # Bare model strings (or None): an explicit ``model`` kwarg takes
        # precedence over the positional value.
        if isinstance(config, ModelString):
            overrides.setdefault("model", str(config))
        elif isinstance(config, str):
            overrides.setdefault("model", config)
        coerced = DSPyLMConfig.model_validate(overrides)

    return coerced, settings or AppSettings()


def _split_config_updates(kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Separate ``DSPyLMConfig`` field overrides from ``dspy.LM`` passthrough kwargs.

    Keys naming a config field are kept as-is. An explicit ``lm_kwargs``
    entry and every unrecognized key are merged, in iteration order, into a
    single ``"lm_kwargs"`` entry, which is omitted when it would be empty.
    """
    known_fields = DSPyLMConfig.model_fields
    field_updates: dict[str, Any] = {}
    passthrough: dict[str, Any] = {}

    for name, value in kwargs.items():
        if name == "lm_kwargs":
            passthrough.update(dict(value or {}))
        elif name in known_fields:
            field_updates[name] = value
        else:
            passthrough[name] = value

    if passthrough:
        field_updates["lm_kwargs"] = passthrough
    return field_updates


def _apply_updates(config: DSPyLMConfig, updates: Mapping[str, Any]) -> DSPyLMConfig:
    """Return a copy of ``config`` with ``updates`` applied.

    ``lm_kwargs`` is merged key-by-key on top of the config's existing
    ``lm_kwargs`` rather than replacing it. The copy is made with
    ``model_copy(update=...)``.
    """
    changes = dict(updates)
    if "lm_kwargs" in changes:
        extra = dict(changes["lm_kwargs"] or {})
        changes["lm_kwargs"] = {**config.lm_kwargs, **extra}
    return config.model_copy(update=changes)


def _profile_lm_kwargs(profile: ChatModelProfile) -> dict[str, Any]:
    """Flatten a profile's constructor and model kwargs into one dict.

    A nested ``model_kwargs`` mapping inside ``constructor_kwargs`` is
    hoisted to the top level, then ``profile.model_kwargs`` is layered on
    top. ``max_retries`` is renamed to ``num_retries`` unless the latter is
    already present.
    """
    flattened = dict(profile.constructor_kwargs)
    nested = flattened.pop("model_kwargs", None)
    if isinstance(nested, Mapping):
        flattened.update(nested)
    flattened.update(profile.model_kwargs)

    if "num_retries" not in flattened and "max_retries" in flattened:
        retries = flattened.pop("max_retries")
        flattened["num_retries"] = retries
    return flattened


def _profile_cache_flag(cache: CacheKeyPolicy | bool | None) -> bool | None:
    """Reduce a profile cache setting to a plain flag.

    Booleans and ``None`` pass through unchanged; any other value is read
    via its ``enabled`` attribute.
    """
    if cache is None or isinstance(cache, bool):
        return cache
    return cache.enabled


def _extract_dspy_usage_raw(value: Any) -> Any:
    """Locate raw usage data on a prediction/LM-like object.

    Probes, in order: a callable ``get_lm_usage()``, well-known mapping keys
    (when ``value`` is a mapping), the same names as attributes, and finally
    a ``history`` attribute. The first truthy result is returned, otherwise
    ``None``.
    """
    if value is None:
        return None

    usage_names = ("lm_usage", "usage_metadata", "usage", "token_usage")

    getter = getattr(value, "get_lm_usage", None)
    if callable(getter):
        reported = getter()
        if reported:
            return reported

    if isinstance(value, Mapping):
        for name in usage_names:
            candidate = value.get(name)
            if candidate:
                return candidate

    # ``history`` is the last resort, after the dedicated usage attributes.
    for name in (*usage_names, "history"):
        candidate = getattr(value, name, None)
        if candidate:
            return candidate
    return None


def _sum_usage_entries(entries: Any) -> dict[str, Any] | None:
    """Sum token counts over usage mappings; return ``None`` if there were none."""
    totals: dict[str, Any] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    counted = False
    for entry in entries:
        snapshot = build_usage_snapshot(entry)
        totals["input_tokens"] += snapshot.input_tokens
        totals["output_tokens"] += snapshot.output_tokens
        totals["total_tokens"] += snapshot.resolved_total_tokens
        counted = True
    return totals if counted else None


def _normalize_dspy_usage(raw: Any) -> dict[str, Any]:
    """Normalize raw usage data into a flat mapping.

    * A mapping that already has token keys is returned as a plain dict.
    * A mapping whose values are usage mappings is summed, with the original
      stored under ``"raw_usage"``. If no value qualifies, the mapping is
      returned unchanged as a dict.
    * A list/tuple has its mapping items normalized and summed, with the
      original items stored under ``"raw_usage"["items"]``.
    * Objects exposing ``model_dump()`` or ``__dict__`` are converted to a
      mapping and normalized recursively.
    * Anything else yields ``{}``.
    """
    if isinstance(raw, Mapping):
        as_dict = dict(raw)
        if _has_token_keys(as_dict):
            return as_dict
        summed = _sum_usage_entries(
            entry
            for entry in as_dict.values()
            if isinstance(entry, Mapping) and _has_token_keys(entry)
        )
        if summed is None:
            return as_dict
        summed["raw_usage"] = as_dict
        return summed

    if isinstance(raw, (list, tuple)):
        normalized_items = (
            _normalize_dspy_usage(item) for item in raw if isinstance(item, Mapping)
        )
        summed = _sum_usage_entries(
            item for item in normalized_items if _has_token_keys(item)
        )
        if summed is not None:
            summed["raw_usage"] = {"items": list(raw)}
            return summed
        # No usable items: fall through to the generic object handling below.

    if hasattr(raw, "model_dump"):
        return _normalize_dspy_usage(raw.model_dump())
    if hasattr(raw, "__dict__"):
        public_attrs = {
            name: attr for name, attr in vars(raw).items() if not name.startswith("_")
        }
        return _normalize_dspy_usage(public_attrs)
    return {}


def _has_token_keys(value: Mapping[str, Any]) -> bool:
    """Return whether ``value`` contains at least one recognized token-count key."""
    token_keys = frozenset(
        {
            "input_tokens",
            "prompt_tokens",
            "prompt_token_count",
            "input_token_count",
            "output_tokens",
            "completion_tokens",
            "candidates_token_count",
            "output_token_count",
            "total_tokens",
            "total_token_count",
        }
    )
    return not token_keys.isdisjoint(value)


def _set_if_not_none(kwargs: dict[str, Any], key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in ``kwargs`` unless it is ``None``."""
    if value is None:
        return
    kwargs[key] = value


__all__ = [
    "CreatedDSPyLMBundle",
    "DSPyDependencyError",
    "DSPyLMConfig",
    "DSPyModelType",
    "build_dspy_usage_event",
    "configure_dspy_lm",
    "create_dspy_lm",
    "create_dspy_lm_bundle",
    "extract_dspy_usage",
    "record_dspy_usage",
    "resolve_dspy_model_name",
]