@@ -74,7 +74,6 @@ public AgentRunResponse(ChatResponse response)
this.RawRepresentation = response;
this.ResponseId = response.ResponseId;
this.Usage = response.Usage;
this.ContinuationToken = response.ContinuationToken;
}

/// <summary>
@@ -75,7 +75,6 @@ public AgentRunResponseUpdate(ChatResponseUpdate chatResponseUpdate)
this.RawRepresentation = chatResponseUpdate;
this.ResponseId = chatResponseUpdate.ResponseId;
this.Role = chatResponseUpdate.Role;
this.ContinuationToken = chatResponseUpdate.ContinuationToken;
}

/// <summary>Gets or sets the name of the author of the response update.</summary>
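Note (context for the two constructor changes above): AgentRunResponse and AgentRunResponseUpdate no longer copy the raw continuation token from ChatResponse / ChatResponseUpdate. ChatClientAgent now wraps the provider token in a ChatClientAgentContinuationToken before exposing it (see WrapContinuationToken in the ChatClientAgent.cs diff below). That type is outside this diff, so the following is only a rough sketch of its shape, inferred from how it is used here and from the IReadOnlyCollection registrations added to JsonContext; the actual implementation may differ.

// Rough sketch only -- inferred from usage in this PR, not the actual implementation.
// Assumes ResponseContinuationToken (Microsoft.Extensions.AI) can be derived from.
using System.Collections.Generic;
using Microsoft.Extensions.AI;

internal sealed class ChatClientAgentContinuationToken : ResponseContinuationToken
{
    public ChatClientAgentContinuationToken(ResponseContinuationToken innerToken) =>
        this.InnerToken = innerToken;

    // The provider-issued token that is handed back to ChatOptions.ContinuationToken.
    public ResponseContinuationToken InnerToken { get; }

    // Input messages of the run, replayed on the last successful resumption so the thread
    // and the context provider still receive them if the initial streaming run failed.
    public IReadOnlyCollection<ChatMessage>? InputMessages { get; init; }

    // Updates received so far, so partial output from earlier attempts is not lost.
    public IReadOnlyCollection<ChatResponseUpdate>? ResponseUpdates { get; init; }

    // Inverse of WrapContinuationToken; the real type presumably also round-trips this
    // state through the serialized token so it survives process boundaries.
    public static ChatClientAgentContinuationToken FromToken(ResponseContinuationToken token) =>
        token as ChatClientAgentContinuationToken ?? new(token);
}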
4 changes: 4 additions & 0 deletions dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;

namespace Microsoft.Agents.AI;

@@ -68,6 +70,8 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(ChatClientAgentThread.ThreadState))]
[JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))]
[JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))]
[JsonSerializable(typeof(IReadOnlyCollection<ChatMessage>))]
[JsonSerializable(typeof(IReadOnlyCollection<ChatResponseUpdate>))]

[ExcludeFromCodeCoverage]
internal sealed partial class JsonContext : JsonSerializerContext;
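Note on the two JsonSerializable registrations added above: they are there so the input messages and response updates carried by the wrapped continuation token can be round-tripped through the library's source-generated JsonContext. A minimal sketch of that round trip is below; it assumes the options built by CreateDefaultOptions() are exposed publicly as AgentJsonUtilities.DefaultOptions (the exact property name is an assumption, not confirmed by this diff).

using System;
using System.Collections.Generic;
using System.Text.Json;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

internal static class ContinuationStateRoundTripSketch
{
    public static void Run()
    {
        IReadOnlyCollection<ChatMessage> inputMessages = [new ChatMessage(ChatRole.User, "Hello")];

        // AgentJsonUtilities.DefaultOptions is assumed to be the public surface over the
        // internal JsonContext shown above; substitute the real accessor if it differs.
        string payload = JsonSerializer.Serialize(inputMessages, AgentJsonUtilities.DefaultOptions);

        IReadOnlyCollection<ChatMessage>? restored =
            JsonSerializer.Deserialize<IReadOnlyCollection<ChatMessage>>(payload, AgentJsonUtilities.DefaultOptions);

        Console.WriteLine(restored?.Count); // 1
    }
}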
117 changes: 91 additions & 26 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
@@ -162,7 +162,10 @@ static Task<ChatResponse> GetResponseAsync(IChatClient chatClient, List<ChatMess

static AgentRunResponse CreateResponse(ChatResponse chatResponse)
{
return new AgentRunResponse(chatResponse);
return new AgentRunResponse(chatResponse)
{
ContinuationToken = WrapContinuationToken(chatResponse.ContinuationToken)
};
}

return this.RunCoreAsync(GetResponseAsync, CreateResponse, messages, thread, options, cancellationToken);
@@ -204,7 +207,9 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
(ChatClientAgentThread safeThread, ChatOptions? chatOptions, List<ChatMessage> inputMessagesForChatClient, IList<ChatMessage>? aiContextProviderMessages) =
await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false);

ValidateStreamResumptionAllowed(chatOptions?.ContinuationToken, safeThread);
var continuationToken = ParseContinuationToken(options?.ContinuationToken);

ValidateStreamResumptionAllowed(continuationToken, safeThread);

var chatClient = this.ChatClient;

@@ -214,7 +219,7 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync

this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, loggingAgentName, this._chatClientType);

List<ChatResponseUpdate> responseUpdates = [];
List<ChatResponseUpdate> responseUpdates = GetResponseUpdates(continuationToken);

IAsyncEnumerator<ChatResponseUpdate> responseUpdatesEnumerator;

@@ -225,7 +230,7 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}

@@ -239,7 +244,7 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}

@@ -251,7 +256,12 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
update.AuthorName ??= this.Name;

responseUpdates.Add(update);
yield return new(update) { AgentId = this.Id };

yield return new(update)
{
AgentId = this.Id,
ContinuationToken = WrapContinuationToken(update.ContinuationToken, GetInputMessages(inputMessages, continuationToken), responseUpdates)
};
}

try
Expand All @@ -260,7 +270,7 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
}
catch (Exception ex)
{
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, inputMessages, aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfFailureAsync(safeThread, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false);
throw;
}
}
@@ -272,10 +282,10 @@ public override async IAsyncEnumerable<AgentRunResponseUpdate> RunStreamingAsync
this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId);

// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);
await NotifyMessageStoreOfNewMessagesAsync(safeThread, GetInputMessages(inputMessages, continuationToken).Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false);

// Notify the AIContextProvider of all new messages.
await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
await NotifyAIContextProviderOfSuccessAsync(safeThread, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc/>
@@ -382,6 +392,8 @@ private async Task<TAgentRunResponse> RunCoreAsync<TAgentRunResponse, TChatClien
(ChatClientAgentThread safeThread, ChatOptions? chatOptions, List<ChatMessage> inputMessagesForChatClient, IList<ChatMessage>? aiContextProviderMessages) =
await this.PrepareThreadAndMessagesAsync(thread, inputMessages, options, cancellationToken).ConfigureAwait(false);

ValidatePollingAllowed(chatOptions?.ContinuationToken, safeThread);

var chatClient = this.ChatClient;

chatClient = ApplyRunOptionsTransformations(options, chatClient);
@@ -583,12 +595,16 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider

static ChatOptions? ApplyBackgroundResponsesProperties(ChatOptions? chatOptions, AgentRunOptions? agentRunOptions)
{
// If any of the background response properties are set in the run options, we should apply both to the chat options.
if (agentRunOptions?.AllowBackgroundResponses is not null || agentRunOptions?.ContinuationToken is not null)
if (agentRunOptions?.AllowBackgroundResponses is not null)
{
chatOptions ??= new ChatOptions();
chatOptions.AllowBackgroundResponses = agentRunOptions.AllowBackgroundResponses;
chatOptions.ContinuationToken = agentRunOptions.ContinuationToken;
}

if ((agentRunOptions?.ContinuationToken ?? chatOptions?.ContinuationToken) is { } continuationToken)
{
chatOptions ??= new ChatOptions();
chatOptions.ContinuationToken = ParseContinuationToken(continuationToken)!.InnerToken;
}

return chatOptions;
@@ -630,11 +646,6 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider
throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token.");
}

if (chatOptions?.ContinuationToken is not null && typedThread.ConversationId is null && typedThread.MessageStore is null)
{
throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs.");
}

List<ChatMessage> inputMessagesForChatClient = [];
IList<ChatMessage>? aiContextProviderMessages = null;

@@ -739,26 +750,80 @@ private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread t
return Task.CompletedTask;
}

private static void ValidateStreamResumptionAllowed(ResponseContinuationToken? continuationToken, ChatClientAgentThread safeThread)
private static void ValidateStreamResumptionAllowed(ChatClientAgentContinuationToken? continuationToken, ChatClientAgentThread safeThread)
{
if (continuationToken is null)
{
return;
}

// If neither input messages nor response updates are present in the token,
// it means it's an initial run that cannot be resumed.
if (continuationToken.InputMessages is not { Count: > 0 } && continuationToken.ResponseUpdates is not { Count: > 0 })
{
throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs.");
}
}

private static void ValidatePollingAllowed(ResponseContinuationToken? continuationToken, ChatClientAgentThread safeThread)
{
if (continuationToken is null)
{
return;
}

// Streaming resumption is only supported with chat history managed by the agent service because, currently, there's no good solution
// to collect updates received in failed runs and pass them to the last successful run so it can store them to the message store.
if (safeThread.ConversationId is null)
// If neither conversation id nor message store are set on the thread,
// it means it's an initial run that cannot be polled.
if (safeThread.ConversationId is null && safeThread.MessageStore is null)
{
throw new InvalidOperationException("Continuation tokens are not allowed to be used for initial runs.");
}
}

private static ChatClientAgentContinuationToken? ParseContinuationToken(ResponseContinuationToken? continuationToken)
{
return continuationToken is null
? null
: ChatClientAgentContinuationToken.FromToken(continuationToken);
}

private static ChatClientAgentContinuationToken? WrapContinuationToken(ResponseContinuationToken? continuationToken, IReadOnlyCollection<ChatMessage>? inputMessages = null, List<ChatResponseUpdate>? responseUpdates = null)
{
if (continuationToken is null)
{
throw new NotSupportedException("Streaming resumption is only supported when chat history is stored and managed by the underlying AI service.");
return null;
}

// Similarly, streaming resumption is not supported when a context provider is used because, currently, there's no good solution
// to collect updates received in failed runs and pass them to the last successful run so it can notify the context provider of the updates.
if (safeThread.AIContextProvider is not null)
return new(continuationToken)
{
// Save input messages to the continuation token so they can be added to the thread and
// provided to the context provider in the last successful streaming resumption run.
// That's necessary for scenarios where initial streaming run fails and streaming is resumed later.
InputMessages = inputMessages?.Count > 0 ? inputMessages : null,

// Save all updates received so far to the continuation token so they can be provided to the
// message store and context provider in the last successful streaming resumption run.
// That's necessary for scenarios where a streaming run fails after some updates were received.
ResponseUpdates = responseUpdates?.Count > 0 ? responseUpdates : null
};
}

private static IReadOnlyCollection<ChatMessage> GetInputMessages(IReadOnlyCollection<ChatMessage> inputMessages, ChatClientAgentContinuationToken? token)
{
// First, use input messages if provided.
if (inputMessages.Count > 0)
{
throw new NotSupportedException("Using context provider with streaming resumption is not supported.");
return inputMessages;
}

// Fallback to messages saved in the continuation token if available.
return token?.InputMessages ?? [];
}

private static List<ChatResponseUpdate> GetResponseUpdates(ChatClientAgentContinuationToken? token)
{
// Restore any previously received updates from the continuation token.
return token?.ResponseUpdates?.ToList() ?? [];
}

private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent";
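For reviewers tracing the ChatClientAgent.cs changes end to end, here is a consumer-side sketch of the streaming resumption flow this PR enables. It is not code from this PR: it assumes an existing ChatClientAgent whose underlying IChatClient supports background responses, an AgentThread created for it, and settable AllowBackgroundResponses / ContinuationToken properties on AgentRunOptions matching the reads in ApplyBackgroundResponsesProperties above; exact overload shapes may differ.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Agents.AI;
using Microsoft.Extensions.AI;

internal static class StreamingResumptionSketch
{
    public static async Task RunAsync(ChatClientAgent agent, AgentThread thread)
    {
        IEnumerable<ChatMessage> messages = [new ChatMessage(ChatRole.User, "Summarize the design document.")];
        ResponseContinuationToken? token = null;

        do
        {
            var options = new AgentRunOptions
            {
                AllowBackgroundResponses = true,
                ContinuationToken = token, // null on the first pass; set when resuming
            };

            token = null;
            await foreach (var update in agent.RunStreamingAsync(messages, thread, options))
            {
                Console.Write(update.Text);
                token = update.ContinuationToken; // stays non-null while output is still pending
            }

            // On resumptions no input messages are passed: the wrapped token already carries
            // them, and the agent rejects messages supplied alongside a continuation token.
            messages = [];
        }
        while (token is not null);
    }
}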