diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 250a29b89..13c2048f4 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -379,6 +379,27 @@ async def _persist_interrupted_turn( ) try: + topic_summary = None + if not context.query_request.conversation_id: + should_generate = context.query_request.generate_topic_summary + if should_generate: + try: + logger.debug( + "Generating topic summary for interrupted new conversation" + ) + topic_summary = await get_topic_summary( + context.query_request.query, + context.client, + responses_params.model, + ) + except Exception as e: # pylint: disable=broad-except + logger.warning( + "Failed to generate topic summary for interrupted turn, " + "request %s: %s", + context.request_id, + e, + ) + completed_at = datetime.datetime.now(datetime.UTC).strftime( "%Y-%m-%dT%H:%M:%SZ" ) @@ -391,7 +412,7 @@ async def _persist_interrupted_turn( summary=turn_summary, query=context.query_request.query, skip_userid_check=context.skip_userid_check, - topic_summary=None, + topic_summary=topic_summary, ) except Exception: # pylint: disable=broad-except logger.exception( diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 3e0670e94..99dee264e 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1388,11 +1388,14 @@ async def mock_generator() -> AsyncIterator[str]: yield "data: token\n\n" raise asyncio.CancelledError() + existing_conv_id = "123e4567-e89b-12d3-a456-426614174000" mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.conversation_id = "conv_123" + mock_context.conversation_id = existing_conv_id mock_context.user_id = "user_123" mock_context.query_request = QueryRequest( - query="test", media_type=MEDIA_TYPE_JSON + query="test", + media_type=MEDIA_TYPE_JSON, + conversation_id=existing_conv_id, # Existing conversation: no topic summary ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False @@ -1400,7 +1403,7 @@ async def mock_generator() -> AsyncIterator[str]: mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" - mock_responses_params.conversation = "conv_123" + mock_responses_params.conversation = existing_conv_id mock_responses_params.input = "test" mock_turn_summary = TurnSummary() @@ -1417,7 +1420,7 @@ async def mock_generator() -> AsyncIterator[str]: new_callable=mocker.AsyncMock, ) - test_request_id = "123e4567-e89b-12d3-a456-426614174000" + test_request_id = "223e4567-e89b-12d3-a456-426614174000" mock_context.request_id = test_request_id result = [] @@ -1436,14 +1439,14 @@ async def mock_generator() -> AsyncIterator[str]: append_turn_mock.assert_called_once_with( mock_context.client, - "conv_123", + existing_conv_id, "test", "You interrupted this request.", ) store_query_results_mock.assert_called_once() call_kwargs = store_query_results_mock.call_args[1] assert call_kwargs["user_id"] == "user_123" - assert call_kwargs["conversation_id"] == "conv_123" + assert call_kwargs["conversation_id"] == existing_conv_id assert call_kwargs["summary"].llm_response == "You interrupted this request." assert call_kwargs["topic_summary"] is None @@ -1451,6 +1454,205 @@ async def mock_generator() -> AsyncIterator[str]: test_request_id ) + @pytest.mark.asyncio + async def test_generate_response_cancelled_persists_topic_summary_for_new_conversation( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test cancelled stream persists topic_summary when generate_topic_summary is True.""" + + async def mock_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + raise asyncio.CancelledError() + + test_request_id = "123e4567-e89b-12d3-a456-426614174001" + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_new_456" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="What is Kubernetes?", + media_type=MEDIA_TYPE_JSON, + conversation_id=None, # New conversation + generate_topic_summary=True, + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + mock_context.request_id = test_request_id + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_new_456" + mock_responses_params.input = "What is Kubernetes?" + + mock_turn_summary = TurnSummary() + mock_turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) + + mocker.patch("app.endpoints.streaming_query.consume_query_tokens") + get_topic_summary_mock = mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + new=mocker.AsyncMock(return_value="Kubernetes container orchestration"), + ) + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + ) + + result = [] + async for item in generate_response( + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + assert any('"event": "interrupted"' in item for item in result) + get_topic_summary_mock.assert_called_once_with( + "What is Kubernetes?", + mock_context.client, + "provider1/model1", + ) + call_kwargs = store_query_results_mock.call_args[1] + assert call_kwargs["topic_summary"] == "Kubernetes container orchestration" + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + + @pytest.mark.asyncio + async def test_generate_response_cancelled_topic_summary_none_when_get_fails( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test cancelled stream persists with topic_summary=None when get_topic_summary raises.""" + + async def mock_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + raise asyncio.CancelledError() + + test_request_id = "123e4567-e89b-12d3-a456-426614174001" + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_new_456" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="What is Kubernetes?", + media_type=MEDIA_TYPE_JSON, + conversation_id=None, # New conversation + generate_topic_summary=True, + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + mock_context.request_id = test_request_id + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_new_456" + mock_responses_params.input = "What is Kubernetes?" + + mock_turn_summary = TurnSummary() + mock_turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) + + mocker.patch("app.endpoints.streaming_query.consume_query_tokens") + mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + new=mocker.AsyncMock(side_effect=Exception("err")), + ) + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + ) + + result = [] + async for item in generate_response( + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + assert any('"event": "interrupted"' in item for item in result) + store_query_results_mock.assert_called_once() + call_kwargs = store_query_results_mock.call_args[1] + assert call_kwargs["topic_summary"] is None + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + + @pytest.mark.asyncio + async def test_generate_response_cancelled_topic_summary_none_when_generate_disabled( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test cancelled stream uses topic_summary=None when generate_topic_summary is False.""" + + async def mock_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + raise asyncio.CancelledError() + + test_request_id = "123e4567-e89b-12d3-a456-426614174002" + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_new_789" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="What is Docker?", + media_type=MEDIA_TYPE_JSON, + conversation_id=None, # New conversation + generate_topic_summary=False, # Explicitly disabled + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + mock_context.request_id = test_request_id + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_new_789" + mock_responses_params.input = "What is Docker?" + + mock_turn_summary = TurnSummary() + mock_turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) + + mocker.patch("app.endpoints.streaming_query.consume_query_tokens") + get_topic_summary_mock = mocker.patch( + "app.endpoints.streaming_query.get_topic_summary", + new=mocker.AsyncMock(return_value="Docker containerization"), + ) + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + ) + + result = [] + async for item in generate_response( + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + assert any('"event": "interrupted"' in item for item in result) + get_topic_summary_mock.assert_not_called() + call_kwargs = store_query_results_mock.call_args[1] + assert call_kwargs["topic_summary"] is None + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + @pytest.mark.asyncio async def test_generate_response_cancelled_stores_results_when_append_fails( self,