Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/agents/realtime/openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pydantic
import websockets
from openai.types.realtime import realtime_audio_config as _rt_audio_config
from openai.types.realtime.audio_transcription import AudioTranscription
from openai.types.realtime.conversation_item import (
ConversationItem,
ConversationItem as OpenAIConversationItem,
Expand Down Expand Up @@ -87,6 +88,7 @@
from agents.realtime._default_tracker import ModelAudioTracker
from agents.realtime.audio_formats import to_realtime_audio_format
from agents.tool import FunctionTool, Tool
from agents.util._pydantic import coerce_model_with_literal_fallback
from agents.util._types import MaybeAwaitable

from ..exceptions import UserError
Expand Down Expand Up @@ -883,7 +885,9 @@ def _get_session_config(
"modalities", DEFAULT_MODEL_SETTINGS.get("modalities")
),
audio=OpenAIRealtimeAudioConfig(
input=OpenAIRealtimeAudioInput(**audio_input_args), # type: ignore[arg-type]
input=OpenAIRealtimeAudioInput(
**_AudioTranscriptionHelper.prepare_audio_input_args(audio_input_args)
),
output=OpenAIRealtimeAudioOutput(**audio_output_args), # type: ignore[arg-type]
),
tools=cast(
Expand Down Expand Up @@ -958,6 +962,38 @@ async def connect(self, options: RealtimeModelConfig) -> None:
await super().connect(sip_options)


class _AudioTranscriptionHelper:
"""Helpers for handling transcription configs with forward compatibility."""

@staticmethod
def prepare_audio_input_args(audio_input_args: dict[str, Any]) -> dict[str, Any]:
"""Prepare audio input args, allowing newer transcription model names."""
prepared_args = dict(audio_input_args)
transcription_config = prepared_args.get("transcription")
if transcription_config is None:
return prepared_args

prepared_args["transcription"] = _AudioTranscriptionHelper._coerce_audio_transcription(
transcription_config
)
return prepared_args

@staticmethod
def _coerce_audio_transcription(transcription_config: Any) -> Any:
"""Convert transcription config into an AudioTranscription, tolerating new model names."""
if isinstance(transcription_config, AudioTranscription):
return transcription_config

if not isinstance(transcription_config, Mapping):
return transcription_config

return coerce_model_with_literal_fallback(
AudioTranscription,
transcription_config,
literal_error_locs=[("model",), ("transcription", "model")],
)


class _ConversionHelper:
@classmethod
def conversation_item_to_realtime_message_item(
Expand Down
57 changes: 57 additions & 0 deletions src/agents/util/_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeVar, cast

import pydantic

# Helpers to tolerate forward-compatible Pydantic literal changes (e.g., date-suffixed model names).

_PydanticModelT = TypeVar("_PydanticModelT", bound="PydanticModelProtocol")


class PydanticModelProtocol(Protocol):
"""Subset of the Pydantic API we need for validation and construction."""

@classmethod
def model_validate(cls: type[_PydanticModelT], data: Any) -> _PydanticModelT: ...

@classmethod
def model_construct(cls: type[_PydanticModelT], **kwargs: Any) -> _PydanticModelT: ...


def coerce_model_with_literal_fallback(
model_cls: type[_PydanticModelT],
data: Any,
*,
literal_error_locs: Sequence[tuple[str, ...]],
) -> _PydanticModelT:
"""Validate data and fall back to model_construct when literal errors occur."""
if isinstance(data, model_cls):
return data

if not isinstance(data, Mapping):
return cast(_PydanticModelT, data)

try:
return model_cls.model_validate(data)
except pydantic.ValidationError as exc:
if _has_literal_error(exc, literal_error_locs):
return model_cls.model_construct(**dict(data))
raise


def _has_literal_error(
exc: pydantic.ValidationError, literal_error_locs: Sequence[tuple[str, ...]]
) -> bool:
"""Return True when a literal_error matches one of the provided locations."""
literal_locs = set(literal_error_locs)
for error in exc.errors():
if error.get("type") != "literal_error":
continue

loc = tuple(error.get("loc") or ())
if loc in literal_locs:
return True

return False
38 changes: 37 additions & 1 deletion tests/realtime/test_openai_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from typing import Any, cast
from unittest.mock import AsyncMock, Mock, patch

import pydantic
import pytest
import websockets
from openai.types.realtime.audio_transcription import AudioTranscription

from agents import Agent
from agents.exceptions import UserError
Expand All @@ -21,7 +23,10 @@
RealtimeModelSendToolOutput,
RealtimeModelSendUserInput,
)
from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel
from agents.realtime.openai_realtime import (
OpenAIRealtimeWebSocketModel,
_AudioTranscriptionHelper,
)


class TestOpenAIRealtimeWebSocketModel:
Expand Down Expand Up @@ -646,6 +651,15 @@ def test_get_and_update_session_config(self, model):
assert cfg.audio is not None and cfg.audio.output is not None
assert cfg.audio.output.voice == "verse"

def test_session_config_allows_new_transcription_models(self, model):
cfg = model._get_session_config(
{"input_audio_transcription": {"model": "gpt-4o-mini-transcribe-2025-12-15"}}
)
assert cfg.audio is not None
assert cfg.audio.input is not None
assert cfg.audio.input.transcription is not None
assert cfg.audio.input.transcription.model == "gpt-4o-mini-transcribe-2025-12-15"

def test_session_config_defaults_audio_formats_when_not_call(self, model):
settings: dict[str, Any] = {}
cfg = model._get_session_config(settings)
Expand All @@ -657,6 +671,28 @@ def test_session_config_defaults_audio_formats_when_not_call(self, model):
assert cfg.audio.output.format is not None
assert cfg.audio.output.format.type == "audio/pcm"

def test_audio_transcription_helper_accepts_new_models(self):
args = {"transcription": {"model": "gpt-4o-mini-transcribe-2025-12-15"}}
prepared = _AudioTranscriptionHelper.prepare_audio_input_args(args)
assert prepared is not args
transcription = prepared["transcription"]
assert isinstance(transcription, AudioTranscription)
assert transcription.model is not None
assert str(transcription.model) == "gpt-4o-mini-transcribe-2025-12-15"

def test_audio_transcription_helper_returns_copy_without_transcription(self):
args = {"format": "pcm16"}
prepared = _AudioTranscriptionHelper.prepare_audio_input_args(args)
assert prepared is not args
assert prepared == args

def test_audio_transcription_helper_raises_on_non_literal_error(self):
# Non-literal validation errors should still surface to the caller.
with pytest.raises(pydantic.ValidationError):
_AudioTranscriptionHelper._coerce_audio_transcription(
{"model": "gpt-4o-mini-transcribe", "language": 123} # invalid language type
)

def test_session_config_preserves_sip_audio_formats(self, model):
model._call_id = "call-123"
settings = {
Expand Down
30 changes: 30 additions & 0 deletions tests/util/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from typing import Literal

import pytest
from pydantic import BaseModel, ValidationError

from agents.util._pydantic import coerce_model_with_literal_fallback


def test_coerce_model_with_literal_fallback_accepts_literal_miss():
class LiteralToyModel(BaseModel):
kind: str
mode: Literal["a", "b"]

obj = coerce_model_with_literal_fallback(
LiteralToyModel,
{"kind": "x", "mode": "c"},
literal_error_locs=[("mode",)],
)
assert isinstance(obj, LiteralToyModel)
assert str(obj.mode) == "c"


def test_coerce_model_with_literal_fallback_propagates_other_errors():
class OtherModel(BaseModel):
field: int

with pytest.raises(ValidationError):
coerce_model_with_literal_fallback(OtherModel, {"field": "oops"}, literal_error_locs=[])