Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,27 @@ async def _persist_interrupted_turn(
)

try:
topic_summary = None
if not context.query_request.conversation_id:
should_generate = context.query_request.generate_topic_summary
if should_generate:
try:
logger.debug(
"Generating topic summary for interrupted new conversation"
)
topic_summary = await get_topic_summary(
context.query_request.query,
context.client,
responses_params.model,
)
Comment on lines +383 to +394
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't put an unbounded summary RPC in front of the interrupt ack.

get_topic_summary() does another Llama Stack request (src/utils/responses.py:114-148), and this branch runs before Line 542 yields the "interrupted" event. If that backend is slow or degraded, a user-initiated cancel now waits on best-effort metadata instead of closing promptly. Please bound this call with a short timeout or move it off the response path.

⏱️ Example way to cap the cancel-path latency
                 try:
                     logger.debug(
                         "Generating topic summary for interrupted new conversation"
                     )
-                    topic_summary = await get_topic_summary(
-                        context.query_request.query,
-                        context.client,
-                        responses_params.model,
-                    )
+                    topic_summary = await asyncio.wait_for(
+                        get_topic_summary(
+                            context.query_request.query,
+                            context.client,
+                            responses_params.model,
+                        ),
+                        timeout=2,
+                    )
+                except TimeoutError as e:
+                    logger.warning(
+                        "Timed out generating topic summary for interrupted turn, "
+                        "request %s: %s",
+                        context.request_id,
+                        e,
+                    )
                 except Exception as e:  # pylint: disable=broad-except
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/app/endpoints/streaming_query.py` around lines 383 - 394, this branch
awaits get_topic_summary(...) inline before yielding the "interrupted"
event, which can delay interrupt acknowledgements. Either bound the call with a
short timeout (wrap it in asyncio.wait_for and catch TimeoutError — since
Python 3.11, asyncio.TimeoutError is an alias of the builtin TimeoutError), or
offload it via asyncio.create_task so it runs in the background and is never
awaited on the critical path. Update the code around
context.query_request.conversation_id and should_generate to use the bounded
call or background task, and ensure errors/timeouts are logged but do not
prevent emitting the interrupted event.

except Exception as e: # pylint: disable=broad-except
logger.warning(
"Failed to generate topic summary for interrupted turn, "
"request %s: %s",
context.request_id,
e,
)

completed_at = datetime.datetime.now(datetime.UTC).strftime(
"%Y-%m-%dT%H:%M:%SZ"
)
Expand All @@ -391,7 +412,7 @@ async def _persist_interrupted_turn(
summary=turn_summary,
query=context.query_request.query,
skip_userid_check=context.skip_userid_check,
topic_summary=None,
topic_summary=topic_summary,
)
except Exception: # pylint: disable=broad-except
logger.exception(
Expand Down
214 changes: 208 additions & 6 deletions tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,19 +1388,22 @@ async def mock_generator() -> AsyncIterator[str]:
yield "data: token\n\n"
raise asyncio.CancelledError()

existing_conv_id = "123e4567-e89b-12d3-a456-426614174000"
mock_context = mocker.Mock(spec=ResponseGeneratorContext)
mock_context.conversation_id = "conv_123"
mock_context.conversation_id = existing_conv_id
mock_context.user_id = "user_123"
mock_context.query_request = QueryRequest(
query="test", media_type=MEDIA_TYPE_JSON
query="test",
media_type=MEDIA_TYPE_JSON,
conversation_id=existing_conv_id, # Existing conversation: no topic summary
) # pyright: ignore[reportCallIssue]
mock_context.started_at = "2024-01-01T00:00:00Z"
mock_context.skip_userid_check = False
mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient)

mock_responses_params = mocker.Mock(spec=ResponsesApiParams)
mock_responses_params.model = "provider1/model1"
mock_responses_params.conversation = "conv_123"
mock_responses_params.conversation = existing_conv_id
mock_responses_params.input = "test"

mock_turn_summary = TurnSummary()
Expand All @@ -1417,7 +1420,7 @@ async def mock_generator() -> AsyncIterator[str]:
new_callable=mocker.AsyncMock,
)

test_request_id = "123e4567-e89b-12d3-a456-426614174000"
test_request_id = "223e4567-e89b-12d3-a456-426614174000"
mock_context.request_id = test_request_id

result = []
Expand All @@ -1436,21 +1439,220 @@ async def mock_generator() -> AsyncIterator[str]:

append_turn_mock.assert_called_once_with(
mock_context.client,
"conv_123",
existing_conv_id,
"test",
"You interrupted this request.",
)
store_query_results_mock.assert_called_once()
call_kwargs = store_query_results_mock.call_args[1]
assert call_kwargs["user_id"] == "user_123"
assert call_kwargs["conversation_id"] == "conv_123"
assert call_kwargs["conversation_id"] == existing_conv_id
assert call_kwargs["summary"].llm_response == "You interrupted this request."
assert call_kwargs["topic_summary"] is None

isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with(
test_request_id
)

@pytest.mark.asyncio
async def test_generate_response_cancelled_persists_topic_summary_for_new_conversation(
    self,
    mocker: MockerFixture,
    isolate_stream_interrupt_registry: Any,
) -> None:
    """Verify a cancelled stream stores the generated topic summary for a new conversation."""

    async def interrupted_stream() -> AsyncIterator[str]:
        # Emit one token, then simulate a client-side cancel.
        yield "data: token\n\n"
        raise asyncio.CancelledError()

    request_id = "123e4567-e89b-12d3-a456-426614174001"

    ctx = mocker.Mock(spec=ResponseGeneratorContext)
    ctx.conversation_id = "conv_new_456"
    ctx.user_id = "user_123"
    # conversation_id=None marks this as a brand-new conversation;
    # generate_topic_summary=True opts in to summary generation.
    ctx.query_request = QueryRequest(
        query="What is Kubernetes?",
        media_type=MEDIA_TYPE_JSON,
        conversation_id=None,
        generate_topic_summary=True,
    )  # pyright: ignore[reportCallIssue]
    ctx.started_at = "2024-01-01T00:00:00Z"
    ctx.skip_userid_check = False
    ctx.client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
    ctx.request_id = request_id

    params = mocker.Mock(spec=ResponsesApiParams)
    params.model = "provider1/model1"
    params.conversation = "conv_new_456"
    params.input = "What is Kubernetes?"

    summary = TurnSummary()
    summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5)

    mocker.patch("app.endpoints.streaming_query.consume_query_tokens")
    topic_summary_mock = mocker.patch(
        "app.endpoints.streaming_query.get_topic_summary",
        new=mocker.AsyncMock(return_value="Kubernetes container orchestration"),
    )
    store_mock = mocker.patch("app.endpoints.streaming_query.store_query_results")
    mocker.patch(
        "app.endpoints.streaming_query.append_turn_to_conversation",
        new_callable=mocker.AsyncMock,
    )

    events = [
        chunk
        async for chunk in generate_response(
            interrupted_stream(), ctx, params, summary
        )
    ]

    # The interrupt must still be surfaced to the client.
    assert any('"event": "interrupted"' in chunk for chunk in events)
    # The summary helper was invoked with the query, client, and model.
    topic_summary_mock.assert_called_once_with(
        "What is Kubernetes?",
        ctx.client,
        "provider1/model1",
    )
    # The generated summary was persisted alongside the turn.
    assert (
        store_mock.call_args[1]["topic_summary"]
        == "Kubernetes container orchestration"
    )
    isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with(
        request_id
    )

@pytest.mark.asyncio
async def test_generate_response_cancelled_topic_summary_none_when_get_fails(
    self,
    mocker: MockerFixture,
    isolate_stream_interrupt_registry: Any,
) -> None:
    """Verify persistence falls back to topic_summary=None if get_topic_summary raises."""

    async def interrupted_stream() -> AsyncIterator[str]:
        # One token, then a simulated client-side cancel.
        yield "data: token\n\n"
        raise asyncio.CancelledError()

    request_id = "123e4567-e89b-12d3-a456-426614174001"

    ctx = mocker.Mock(spec=ResponseGeneratorContext)
    ctx.conversation_id = "conv_new_456"
    ctx.user_id = "user_123"
    # New conversation with summary generation enabled — but the
    # helper below is patched to fail.
    ctx.query_request = QueryRequest(
        query="What is Kubernetes?",
        media_type=MEDIA_TYPE_JSON,
        conversation_id=None,
        generate_topic_summary=True,
    )  # pyright: ignore[reportCallIssue]
    ctx.started_at = "2024-01-01T00:00:00Z"
    ctx.skip_userid_check = False
    ctx.client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
    ctx.request_id = request_id

    params = mocker.Mock(spec=ResponsesApiParams)
    params.model = "provider1/model1"
    params.conversation = "conv_new_456"
    params.input = "What is Kubernetes?"

    summary = TurnSummary()
    summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5)

    mocker.patch("app.endpoints.streaming_query.consume_query_tokens")
    # Summary generation blows up; the turn must still be persisted.
    mocker.patch(
        "app.endpoints.streaming_query.get_topic_summary",
        new=mocker.AsyncMock(side_effect=Exception("err")),
    )
    store_mock = mocker.patch("app.endpoints.streaming_query.store_query_results")
    mocker.patch(
        "app.endpoints.streaming_query.append_turn_to_conversation",
        new_callable=mocker.AsyncMock,
    )

    events = [
        chunk
        async for chunk in generate_response(
            interrupted_stream(), ctx, params, summary
        )
    ]

    assert any('"event": "interrupted"' in chunk for chunk in events)
    # Persistence happened exactly once, with the summary defaulted to None.
    store_mock.assert_called_once()
    assert store_mock.call_args[1]["topic_summary"] is None
    isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with(
        request_id
    )

@pytest.mark.asyncio
async def test_generate_response_cancelled_topic_summary_none_when_generate_disabled(
    self,
    mocker: MockerFixture,
    isolate_stream_interrupt_registry: Any,
) -> None:
    """Verify no topic summary is generated or stored when the flag is disabled."""

    async def interrupted_stream() -> AsyncIterator[str]:
        # One token, then a simulated client-side cancel.
        yield "data: token\n\n"
        raise asyncio.CancelledError()

    request_id = "123e4567-e89b-12d3-a456-426614174002"

    ctx = mocker.Mock(spec=ResponseGeneratorContext)
    ctx.conversation_id = "conv_new_789"
    ctx.user_id = "user_123"
    # New conversation, but topic-summary generation is explicitly off.
    ctx.query_request = QueryRequest(
        query="What is Docker?",
        media_type=MEDIA_TYPE_JSON,
        conversation_id=None,
        generate_topic_summary=False,
    )  # pyright: ignore[reportCallIssue]
    ctx.started_at = "2024-01-01T00:00:00Z"
    ctx.skip_userid_check = False
    ctx.client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
    ctx.request_id = request_id

    params = mocker.Mock(spec=ResponsesApiParams)
    params.model = "provider1/model1"
    params.conversation = "conv_new_789"
    params.input = "What is Docker?"

    summary = TurnSummary()
    summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5)

    mocker.patch("app.endpoints.streaming_query.consume_query_tokens")
    # Patched with a real return value so an accidental call would be visible.
    topic_summary_mock = mocker.patch(
        "app.endpoints.streaming_query.get_topic_summary",
        new=mocker.AsyncMock(return_value="Docker containerization"),
    )
    store_mock = mocker.patch("app.endpoints.streaming_query.store_query_results")
    mocker.patch(
        "app.endpoints.streaming_query.append_turn_to_conversation",
        new_callable=mocker.AsyncMock,
    )

    events = [
        chunk
        async for chunk in generate_response(
            interrupted_stream(), ctx, params, summary
        )
    ]

    assert any('"event": "interrupted"' in chunk for chunk in events)
    # The flag is off: the helper must never run, and None is stored.
    topic_summary_mock.assert_not_called()
    assert store_mock.call_args[1]["topic_summary"] is None
    isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with(
        request_id
    )

@pytest.mark.asyncio
async def test_generate_response_cancelled_stores_results_when_append_fails(
self,
Expand Down
Loading