"""Optional DSPy model integration.

Purpose:
    Create DSPy ``LM`` instances from the same model strings, serializable
    profiles, and runtime identity used by the LangChain-first OOAI surface.

Design:
    - Keep DSPy optional and lazily imported.
    - Resolve all model choices to LiteLLM-style ``provider/model`` strings,
      which is the shape DSPy's ``LM`` constructor expects.
    - Expose only model-substrate helpers here. DSPy programs, runnables,
      LangGraph nodes, optimizers, and artifacts belong in ``ooai-agents``.

Examples:
    >>> config = DSPyLMConfig(model="openai:gpt-5-mini", temperature=0)
    >>> config.resolve_model_name()
    'openai/gpt-5-mini'
"""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Literal
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
from .callbacks import BudgetPolicy, UsageEvent, UsageRecorder
from .factory import native_environment_overrides, resolve_factory_settings, resolve_model_string
from .metadata import ModelInfo, build_usage_snapshot, calculate_cost, get_model_info
from .profiles import CacheKeyPolicy, ChatModelProfile, LLM
from .providers import Provider
from .reasoning import ReasoningInput, build_reasoning_resolution
from .settings import AppSettings, ModelAliasName, ModelPresetName
from .types import ModelString
if TYPE_CHECKING:
from types import ModuleType
[docs]
# Closed set of values accepted for the ``model_type`` keyword that this
# module forwards to ``dspy.LM``.
DSPyModelType = Literal["chat", "text", "responses"]
"""DSPy model protocol style passed to ``dspy.LM``."""
class DSPyDependencyError(ImportError):
    """Signal that a DSPy helper was invoked while ``dspy`` is not installed.

    Deriving from ``ImportError`` keeps callers that guard optional imports
    with ``except ImportError`` working unchanged.
    """
class DSPyLMConfig(BaseModel):
    """Serializable configuration for creating a DSPy ``LM``.

    Model selection (``model`` / ``alias`` / ``provider`` / ``preset``) is
    resolved through the shared OOAI factory helpers. The remaining fields are
    either forwarded to ``dspy.LM`` by :meth:`to_lm_kwargs` or kept as
    identity/labeling data for :meth:`trace_metadata`.

    Unknown fields are rejected (``extra="forbid"``), so every keyword used by
    :meth:`from_profile` must be declared here.

    Args:
        model: Explicit OOAI model string, such as ``"openai:gpt-5-mini"``.
        alias: Optional configured model alias when ``model`` is omitted.
        provider: Optional provider for alias/preset resolution.
        preset: Provider preset name when resolving by provider.
        model_type: DSPy LM type: ``"chat"``, ``"text"``, or ``"responses"``.
        parallel_tool_calls: Optional flag forwarded to ``dspy.LM`` when set.
        lm_kwargs: DSPy/LiteLLM passthrough kwargs for provider-specific needs.
        tags: Free-form labels copied from the source profile.
        metadata: Free-form metadata copied from the source profile.
        cost_labels: Cost-attribution labels copied from the source profile.
    """

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    # --- model selection -------------------------------------------------
    model: str | None = None
    alias: ModelAliasName | None = None
    provider: str | Provider | None = None
    preset: ModelPresetName = "default"
    model_type: DSPyModelType = "chat"

    # --- sampling / transport options forwarded to ``dspy.LM`` -----------
    temperature: float | None = None
    max_tokens: int | None = Field(default=None, ge=1)
    top_p: float | None = Field(default=None, ge=0, le=1)
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    seed: int | None = None
    stop: str | list[str] | None = None
    cache: bool | None = None
    callbacks: list[Any] = Field(default_factory=list)
    num_retries: int | None = Field(default=None, ge=0)
    timeout: float | None = Field(default=None, ge=0)
    # Referenced by ``from_profile`` and ``to_lm_kwargs``; it must be declared
    # because the model forbids extra fields.
    parallel_tool_calls: bool | None = None
    reasoning: ReasoningInput = None
    finetuning_model: str | None = None
    launch_kwargs: dict[str, Any] | None = None
    train_kwargs: dict[str, Any] | None = None
    use_developer_role: bool | None = None
    lm_kwargs: dict[str, Any] = Field(default_factory=dict)

    # --- model-catalog refresh behaviour ---------------------------------
    auto_refresh_models: bool | None = None
    force_model_refresh: bool = False

    # --- identity and labeling -------------------------------------------
    runtime_id: str | None = None
    runtime_uuid: UUID | str | None = None
    profile_id: str | None = None
    # ``from_profile`` passes ``tags=`` and ``metadata=``; without these
    # declarations ``extra="forbid"`` makes that call fail validation.
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
    cost_labels: dict[str, str] = Field(default_factory=dict)

    @classmethod
    def from_profile(
        cls,
        profile: ChatModelProfile,
        *,
        model_type: DSPyModelType = "chat",
    ) -> DSPyLMConfig:
        """Build a DSPy config from a serializable chat-model profile.

        Args:
            profile: Source profile whose fields are copied one-to-one, except
                that ``max_retries`` becomes ``num_retries`` and the cache
                policy is collapsed to a boolean flag.
            model_type: DSPy protocol style for the resulting config.

        Returns:
            A new config carrying the profile's settings and ``profile_id``.
        """
        return cls(
            model=profile.model,
            alias=profile.alias,
            provider=profile.provider,
            preset=profile.preset,
            model_type=model_type,
            temperature=profile.temperature,
            max_tokens=profile.max_tokens,
            top_p=profile.top_p,
            frequency_penalty=profile.frequency_penalty,
            presence_penalty=profile.presence_penalty,
            seed=profile.seed,
            stop=profile.stop,
            cache=_profile_cache_flag(profile.cache),
            num_retries=profile.max_retries,
            timeout=profile.timeout,
            parallel_tool_calls=profile.parallel_tool_calls,
            reasoning=profile.reasoning,
            lm_kwargs=_profile_lm_kwargs(profile),
            auto_refresh_models=profile.auto_refresh_models,
            force_model_refresh=profile.force_model_refresh,
            profile_id=profile.id,
            # Copy the containers so the config never aliases profile state.
            tags=list(profile.tags),
            metadata=dict(profile.metadata),
            cost_labels=dict(profile.cost_labels),
        )

    @classmethod
    def from_runtime(
        cls,
        runtime: LLM,
        *,
        model_type: DSPyModelType = "chat",
    ) -> DSPyLMConfig:
        """Build a DSPy config from an ``LLM`` runtime.

        The runtime's profile supplies every model setting; the runtime's
        ``id`` and ``uuid`` are then stamped onto a copy of that config.
        """
        config = cls.from_profile(runtime.profile, model_type=model_type)
        return config.model_copy(
            update={
                "runtime_id": runtime.id,
                "runtime_uuid": runtime.uuid,
            }
        )

    def resolve_model_string(self, *, settings: AppSettings | None = None) -> ModelString:
        """Resolve this config to an OOAI model string.

        Args:
            settings: Optional application settings; passed through
                ``resolve_factory_settings`` together with this config's
                refresh flags before resolution.
        """
        resolved_settings = resolve_factory_settings(
            settings,
            auto_refresh_models=self.auto_refresh_models,
            force_model_refresh=self.force_model_refresh,
        )
        return resolve_model_string(
            settings=resolved_settings,
            model=self.model,
            alias=self.alias,
            provider=self.provider,
            preset=self.preset,
        )

    def resolve_model_name(self, *, settings: AppSettings | None = None) -> str:
        """Resolve this config to DSPy's LiteLLM-style model name."""
        return self.resolve_model_string(settings=settings).as_litellm()

    def to_lm_kwargs(self, *, settings: AppSettings | None = None) -> dict[str, Any]:
        """Return keyword arguments for ``dspy.LM``.

        Later sources override earlier ones: reasoning-derived constructor
        kwargs, then ``lm_kwargs``, then ``model_type``, then every explicitly
        set (non-``None``) field, and finally ``callbacks`` when non-empty.
        The model name itself is not included; callers pass it positionally.
        """
        model = self.resolve_model_string(settings=settings)
        kwargs: dict[str, Any] = {}
        reasoning = build_reasoning_resolution(
            model=model,
            provider=self.provider,
            reasoning=self.reasoning,
        )
        if reasoning is not None:
            kwargs.update(reasoning.constructor_kwargs)
        kwargs.update(self.lm_kwargs)
        kwargs["model_type"] = self.model_type

        # Explicit fields win over passthrough kwargs, but only when set.
        explicit_fields = (
            ("temperature", self.temperature),
            ("max_tokens", self.max_tokens),
            ("top_p", self.top_p),
            ("frequency_penalty", self.frequency_penalty),
            ("presence_penalty", self.presence_penalty),
            ("seed", self.seed),
            ("stop", self.stop),
            ("cache", self.cache),
            ("num_retries", self.num_retries),
            ("timeout", self.timeout),
            ("parallel_tool_calls", self.parallel_tool_calls),
            ("finetuning_model", self.finetuning_model),
            ("launch_kwargs", self.launch_kwargs),
            ("train_kwargs", self.train_kwargs),
            ("use_developer_role", self.use_developer_role),
        )
        for key, value in explicit_fields:
            _set_if_not_none(kwargs, key, value)

        if self.callbacks:
            kwargs["callbacks"] = list(self.callbacks)
        return kwargs

    def trace_metadata(
        self,
        *,
        model: ModelString | str | None = None,
        settings: AppSettings | None = None,
    ) -> dict[str, Any]:
        """Return identity and labeling metadata describing this LM.

        ``create_dspy_lm_bundle`` calls this with ``model=`` and ``settings=``
        to populate the bundle's ``trace_metadata``.

        NOTE(review): the key names below are chosen here; confirm them
        against whatever consumes the bundle's trace metadata.

        Args:
            model: Already-resolved model; resolved from this config when
                omitted.
            settings: Settings used only when ``model`` must be resolved.

        Returns:
            A new dict: a copy of ``metadata`` overlaid with the model string,
            ``model_type``, any set identity fields, and non-empty ``tags`` /
            ``cost_labels``.
        """
        resolved = model if model is not None else self.resolve_model_string(settings=settings)
        payload: dict[str, Any] = dict(self.metadata)
        payload["model"] = str(resolved)
        payload["model_type"] = self.model_type
        _set_if_not_none(payload, "profile_id", self.profile_id)
        _set_if_not_none(payload, "runtime_id", self.runtime_id)
        if self.runtime_uuid is not None:
            # Stringify so the payload stays plain-data serializable.
            payload["runtime_uuid"] = str(self.runtime_uuid)
        if self.tags:
            payload["tags"] = list(self.tags)
        if self.cost_labels:
            payload["cost_labels"] = dict(self.cost_labels)
        return payload
@dataclass(slots=True, frozen=True)
[docs]
class CreatedDSPyLMBundle:
"""DSPy LM plus resolved OOAI metadata."""
def resolve_dspy_model_name(
    model: str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    alias: ModelAliasName | None = None,
    provider: Provider | str | None = None,
    preset: ModelPresetName = "default",
    auto_refresh_models: bool | None = None,
    force_model_refresh: bool = False,
) -> str:
    """Turn a model choice into the LiteLLM-style name DSPy expects.

    A throwaway ``DSPyLMConfig`` is built from the selection arguments and
    asked to resolve itself, so this shares the resolution path used when an
    LM is actually created.

    Args:
        model: Explicit model; stringified when given.
        settings: Optional application settings used for resolution.
        alias: Configured alias used when ``model`` is omitted.
        provider: Provider used for alias/preset resolution.
        preset: Provider preset name.
        auto_refresh_models: Forwarded to the config's refresh flag.
        force_model_refresh: Forwarded to the config's refresh flag.

    Returns:
        The value of ``DSPyLMConfig.resolve_model_name`` for that selection.
    """
    explicit_model = None if model is None else str(model)
    selection = DSPyLMConfig(
        model=explicit_model,
        alias=alias,
        provider=provider,
        preset=preset,
        auto_refresh_models=auto_refresh_models,
        force_model_refresh=force_model_refresh,
    )
    return selection.resolve_model_name(settings=settings)
def create_dspy_lm(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    **kwargs: Any,
) -> Any:
    """Create a native DSPy ``LM`` from OOAI config.

    Convenience wrapper around ``create_dspy_lm_bundle`` that discards the
    resolved metadata and hands back only the LM object.

    Unknown keyword arguments are passed through to ``dspy.LM`` via
    ``DSPyLMConfig.lm_kwargs``.
    """
    bundle = create_dspy_lm_bundle(config, settings=settings, **kwargs)
    return bundle.lm
def create_dspy_lm_bundle(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None = None,
    *,
    settings: AppSettings | None = None,
    **kwargs: Any,
) -> CreatedDSPyLMBundle:
    """Create a DSPy ``LM`` and return resolved OOAI metadata with it.

    Steps, in order: import DSPy (failing early when it is missing), coerce
    ``config`` plus ``kwargs`` into a ``DSPyLMConfig``, resolve the model,
    construct ``dspy.LM`` inside ``native_environment_overrides``, then look
    up model info and trace metadata.

    Raises:
        DSPyDependencyError: When the ``dspy`` package cannot be imported.
    """
    dspy_module = _import_dspy()
    effective_config, effective_settings = _coerce_config(
        config,
        settings=settings,
        kwargs=kwargs,
    )
    resolved_model = effective_config.resolve_model_string(settings=effective_settings)
    constructor_kwargs = effective_config.to_lm_kwargs(settings=effective_settings)

    # Only LM construction runs under the environment overrides.
    with native_environment_overrides(effective_settings):
        language_model = dspy_module.LM(resolved_model.as_litellm(), **constructor_kwargs)

    model_info = get_model_info(
        model=resolved_model,
        settings=effective_settings,
        provider=effective_config.provider,
    )
    trace = effective_config.trace_metadata(
        model=resolved_model,
        settings=effective_settings,
    )
    return CreatedDSPyLMBundle(
        model=resolved_model,
        lm=language_model,
        metadata=model_info,
        config=effective_config,
        trace_metadata=trace,
    )
def build_dspy_usage_event(
    *,
    prediction: Any,
    model: str | ModelString,
    settings: AppSettings | None = None,
    budget: BudgetPolicy | None = None,
    profile: Mapping[str, Any] | None = None,
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Build a usage event from a DSPy prediction when usage is available.

    Args:
        prediction: Object handed to ``extract_dspy_usage``.
        model: Model the prediction was produced with; parsed and
            canonicalized before lookup.
        settings: Optional settings forwarded to ``get_model_info``.
        budget: Accepted for signature parity with ``record_dspy_usage``;
            this function does not read it.
        profile: Optional profile mapping forwarded to ``get_model_info``.
        run_name: Copied onto the event.
        tags: Copied onto the event as a list (empty when omitted).
        metadata: Copied onto the event as a dict (empty when omitted).
        cost_labels: Copied onto the event as a dict (empty when omitted).

    Returns:
        A ``UsageEvent`` with ``source="dspy"``, or ``None`` when no usage
        could be extracted from ``prediction``.
    """
    usage_payload = extract_dspy_usage(prediction)
    if not usage_payload:
        return None

    canonical_model = ModelString.parse(model).canonical()
    info = get_model_info(model=canonical_model, settings=settings, profile=profile)
    snapshot = build_usage_snapshot(usage_payload)
    cost_usd = calculate_cost(info, snapshot)

    event_tags = list(tags) if tags else []
    event_metadata = dict(metadata) if metadata else {}
    event_cost_labels = dict(cost_labels) if cost_labels else {}
    return UsageEvent(
        source="dspy",
        model=canonical_model,
        input_tokens=snapshot.input_tokens,
        output_tokens=snapshot.output_tokens,
        total_tokens=snapshot.resolved_total_tokens,
        cost_usd=cost_usd,
        count_source="dspy_usage_metadata",
        run_name=run_name,
        tags=event_tags,
        metadata=event_metadata,
        cost_labels=event_cost_labels,
        raw=snapshot.raw_usage,
    )
def record_dspy_usage(
    recorder: UsageRecorder,
    *,
    prediction: Any,
    model: str | ModelString,
    settings: AppSettings | None = None,
    budget: BudgetPolicy | None = None,
    profile: Mapping[str, Any] | None = None,
    run_name: str | None = None,
    tags: list[str] | tuple[str, ...] | None = None,
    metadata: Mapping[str, Any] | None = None,
    cost_labels: Mapping[str, str] | None = None,
) -> UsageEvent | None:
    """Record DSPy prediction usage when the prediction exposes counts.

    Delegates event construction to ``build_dspy_usage_event`` with the same
    arguments, then hands the event to ``recorder.record`` along with
    ``budget``.

    Returns:
        Whatever ``recorder.record`` returns, or ``None`` when no usage event
        could be built (in which case the recorder is not called).
    """
    usage_event = build_dspy_usage_event(
        prediction=prediction,
        model=model,
        settings=settings,
        budget=budget,
        profile=profile,
        run_name=run_name,
        tags=tags,
        metadata=metadata,
        cost_labels=cost_labels,
    )
    return None if usage_event is None else recorder.record(usage_event, budget=budget)
def _import_dspy() -> "ModuleType":
    """Import and return the ``dspy`` module on first use.

    The import lives inside the function so that merely importing this module
    never requires DSPy.

    Raises:
        DSPyDependencyError: When ``import dspy`` raises
            ``ModuleNotFoundError``; the original error is chained as cause.
    """
    try:
        import dspy
    except ModuleNotFoundError as missing:
        install_hint = (
            "DSPy support requires the optional dependency. "
            'Install it with `pip install "ooai-llm[dspy]"` or `pdm add "ooai-llm[dspy]"`.'
        )
        raise DSPyDependencyError(install_hint) from missing
    else:
        return dspy
def _coerce_config(
    config: ChatModelProfile | DSPyLMConfig | LLM | str | ModelString | None,
    *,
    settings: AppSettings | None,
    kwargs: Mapping[str, Any],
) -> tuple[DSPyLMConfig, AppSettings]:
    """Normalize any accepted ``config`` shape into ``(DSPyLMConfig, settings)``.

    ``kwargs`` are first split into config-field updates and ``lm_kwargs``
    passthrough. Dispatch then depends on the type of ``config``:

    - ``LLM``: built via ``DSPyLMConfig.from_runtime``; settings are
      ``settings or config.settings`` (no ``AppSettings()`` fallback here).
    - ``DSPyLMConfig``: copied with the updates applied.
    - ``ChatModelProfile``: built via ``DSPyLMConfig.from_profile``.
    - ``ModelString`` / ``str`` / ``None``: validated from the updates alone,
      using ``config`` as the default for ``model`` when it is given.

    Every branch except the ``LLM`` one falls back to ``AppSettings()`` when
    no settings were supplied.
    """
    overrides = _split_config_updates(kwargs)

    if isinstance(config, LLM):
        runtime_settings = settings or config.settings
        requested_type = overrides.pop("model_type", "chat")
        from_runtime = DSPyLMConfig.from_runtime(config, model_type=requested_type)
        return _apply_updates(from_runtime, overrides), runtime_settings

    if isinstance(config, DSPyLMConfig):
        return _apply_updates(config, overrides), settings or AppSettings()

    if isinstance(config, ChatModelProfile):
        requested_type = overrides.pop("model_type", "chat")
        from_profile = DSPyLMConfig.from_profile(config, model_type=requested_type)
        return _apply_updates(from_profile, overrides), settings or AppSettings()

    # Plain model references: an explicit ``model=`` keyword takes precedence
    # over the positional value, hence ``setdefault``.
    if isinstance(config, ModelString):
        overrides.setdefault("model", str(config))
    elif isinstance(config, str):
        overrides.setdefault("model", config)
    return DSPyLMConfig.model_validate(overrides), settings or AppSettings()
def _split_config_updates(kwargs: Mapping[str, Any]) -> dict[str, Any]:
    """Separate caller kwargs into config-field updates and LM passthrough.

    Keys naming a ``DSPyLMConfig`` field become updates. An explicit
    ``lm_kwargs`` mapping and every unrecognized key are pooled, in iteration
    order, into a single ``lm_kwargs`` entry. That entry is only present in
    the result when the pool is non-empty.
    """
    known_fields = set(DSPyLMConfig.model_fields)
    field_updates: dict[str, Any] = {}
    passthrough: dict[str, Any] = {}

    for name, value in kwargs.items():
        if name == "lm_kwargs":
            # Tolerate ``lm_kwargs=None`` by treating it as empty.
            passthrough.update(dict(value or {}))
            continue
        bucket = field_updates if name in known_fields else passthrough
        bucket[name] = value

    if passthrough:
        field_updates["lm_kwargs"] = {
            **dict(field_updates.get("lm_kwargs") or {}),
            **passthrough,
        }
    return field_updates
def _apply_updates(config: DSPyLMConfig, updates: Mapping[str, Any]) -> DSPyLMConfig:
    """Return a copy of ``config`` with ``updates`` applied.

    ``lm_kwargs`` is merged rather than replaced: the config's existing
    passthrough kwargs are kept and the update's entries are laid on top.
    Neither input is mutated.
    """
    changes = dict(updates)
    if "lm_kwargs" in changes:
        changes["lm_kwargs"] = {
            **dict(config.lm_kwargs),
            **dict(changes["lm_kwargs"] or {}),
        }
    return config.model_copy(update=changes)
def _profile_lm_kwargs(profile: ChatModelProfile) -> dict[str, Any]:
    """Flatten a profile's constructor/model kwargs into DSPy passthrough kwargs.

    Layering, lowest to highest precedence: ``profile.constructor_kwargs``,
    a nested ``"model_kwargs"`` mapping found inside it (lifted to the top
    level), then ``profile.model_kwargs``. A ``max_retries`` entry is renamed
    to ``num_retries`` unless ``num_retries`` is already present.
    """
    flattened = dict(profile.constructor_kwargs)

    nested = flattened.pop("model_kwargs", None)
    if isinstance(nested, Mapping):
        flattened.update(nested)
    flattened.update(profile.model_kwargs)

    if "num_retries" in flattened or "max_retries" not in flattened:
        return flattened
    flattened["num_retries"] = flattened.pop("max_retries")
    return flattened
def _profile_cache_flag(cache: CacheKeyPolicy | bool | None) -> bool | None:
if isinstance(cache, bool) or cache is None:
return cache
return cache.enabled
def _extract_dspy_usage_raw(value: Any) -> Any:
if value is None:
return None
get_lm_usage = getattr(value, "get_lm_usage", None)
if callable(get_lm_usage):
raw = get_lm_usage()
if raw:
return raw
if isinstance(value, Mapping):
for key in ("lm_usage", "usage_metadata", "usage", "token_usage"):
if value.get(key):
return value[key]
for attr in ("lm_usage", "usage_metadata", "usage", "token_usage"):
raw = getattr(value, attr, None)
if raw:
return raw
history = getattr(value, "history", None)
if history:
return history
return None
def _normalize_dspy_usage(raw: Any) -> dict[str, Any]:
    """Reduce a raw usage payload to a single token-count mapping.

    - A mapping that already has token keys is returned as a plain dict copy.
    - A mapping whose *values* are token-bearing mappings is summed into
      ``input_tokens`` / ``output_tokens`` / ``total_tokens`` with the
      original kept under ``raw_usage``; if no value qualifies, the copy is
      returned unchanged.
    - A list or tuple is summed over its mapping items after normalizing each
      one recursively; the items are kept under ``raw_usage["items"]``.
    - Objects exposing ``model_dump`` or ``__dict__`` are converted to a
      mapping (public attributes only for ``__dict__``) and normalized.
    - Anything else yields an empty dict.
    """

    def _summed(snapshots: list[Any], raw_usage: dict[str, Any]) -> dict[str, Any]:
        # Fold per-entry snapshots into the aggregate shape used above.
        return {
            "input_tokens": sum(item.input_tokens for item in snapshots),
            "output_tokens": sum(item.output_tokens for item in snapshots),
            "total_tokens": sum(item.resolved_total_tokens for item in snapshots),
            "raw_usage": raw_usage,
        }

    if isinstance(raw, Mapping):
        payload = dict(raw)
        if _has_token_keys(payload):
            return payload
        per_entry = [
            build_usage_snapshot(entry)
            for entry in payload.values()
            if isinstance(entry, Mapping) and _has_token_keys(entry)
        ]
        return _summed(per_entry, payload) if per_entry else payload

    if isinstance(raw, (list, tuple)):
        per_item = []
        for entry in raw:
            if not isinstance(entry, Mapping):
                continue
            flattened = _normalize_dspy_usage(entry)
            if _has_token_keys(flattened):
                per_item.append(build_usage_snapshot(flattened))
        if per_item:
            return _summed(per_item, {"items": list(raw)})
        # No usable items: fall through to the attribute-based probes below.

    if hasattr(raw, "model_dump"):
        return _normalize_dspy_usage(raw.model_dump())
    if hasattr(raw, "__dict__"):
        public_attrs = {
            name: attr for name, attr in vars(raw).items() if not name.startswith("_")
        }
        return _normalize_dspy_usage(public_attrs)
    return {}
def _has_token_keys(value: Mapping[str, Any]) -> bool:
keys = set(value)
return bool(
keys
& {
"input_tokens",
"prompt_tokens",
"prompt_token_count",
"input_token_count",
"output_tokens",
"completion_tokens",
"candidates_token_count",
"output_token_count",
"total_tokens",
"total_token_count",
}
)
def _set_if_not_none(kwargs: dict[str, Any], key: str, value: Any) -> None:
if value is not None:
kwargs[key] = value
# Names exported by ``from <this module> import *``.
# NOTE(review): ``configure_dspy_lm`` and ``extract_dspy_usage`` are listed
# here but are not defined in the portion of the file visible above —
# confirm they exist elsewhere in the module.
__all__ = [
    "CreatedDSPyLMBundle",
    "DSPyDependencyError",
    "DSPyLMConfig",
    "DSPyModelType",
    "build_dspy_usage_event",
    "configure_dspy_lm",
    "create_dspy_lm",
    "create_dspy_lm_bundle",
    "extract_dspy_usage",
    "record_dspy_usage",
    "resolve_dspy_model_name",
]