diff --git a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java index b084de860..e193e7686 100644 --- a/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java +++ b/core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java @@ -73,37 +73,12 @@ public Completable compact(Session session, BaseSessionService sessionService) { logger.debug("Running tail retention event compaction for session {}", session.id()); return Maybe.just(session.events()) - .filter(this::shouldCompact) - .flatMap(events -> getCompactionEvents(events)) + .flatMap(this::getCompactionEvents) .flatMap(summarizer::summarizeEvents) .flatMapSingle(e -> sessionService.appendEvent(session, e)) .ignoreElement(); } - private boolean shouldCompact(List events) { - int count = getLatestPromptTokenCount(events).orElse(0); - - // TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not - // available. - if (count <= tokenThreshold) { - logger.debug( - "Skipping compaction. Prompt token count {} is within threshold {}", - count, - tokenThreshold); - return false; - } - return true; - } - - private Optional getLatestPromptTokenCount(List events) { - return Lists.reverse(events).stream() - .map(Event::usageMetadata) - .flatMap(Optional::stream) - .map(GenerateContentResponseUsageMetadata::promptTokenCount) - .flatMap(Optional::stream) - .findFirst(); - } - /** * Identifies events to be compacted based on the tail retention strategy. * @@ -161,8 +136,19 @@ private Optional getLatestPromptTokenCount(List events) { * together. The new compaction event will cover the range from the start of the included * compaction event (C2, T=1) to the end of the new events (E4, T=4). * + * + * @param events The list of events to process. */ private Maybe> getCompactionEvents(List events) { + Optional count = getLatestPromptTokenCount(events); + if (count.isPresent() && count.get() <= tokenThreshold) { + logger.debug( + "Skipping compaction. Prompt token count {} is within threshold {}", + count.get(), + tokenThreshold); + return Maybe.empty(); + } + long compactionEndTimestamp = Long.MIN_VALUE; Event lastCompactionEvent = null; List eventsToSummarize = new ArrayList<>(); @@ -195,11 +181,6 @@ private Maybe> getCompactionEvents(List events) { } } - // If there are not enough events to summarize, we can return early. - if (eventsToSummarize.size() <= retentionSize) { - return Maybe.empty(); - } - // Add the last compaction event to the list of events to summarize. // This is to ensure that the last compaction event is included in the summary. if (lastCompactionEvent != null) { @@ -214,6 +195,22 @@ private Maybe> getCompactionEvents(List events) { Collections.reverse(eventsToSummarize); + if (count.isEmpty()) { + int estimatedCount = estimateTokenCount(eventsToSummarize); + if (estimatedCount <= tokenThreshold) { + logger.debug( + "Skipping compaction. Estimated prompt token count {} is within threshold {}", + estimatedCount, + tokenThreshold); + return Maybe.empty(); + } + } + + // If there are not enough events to summarize, we can return early. + if (eventsToSummarize.size() <= retentionSize) { + return Maybe.empty(); + } + // Apply retention: keep the most recent 'retentionSize' events out of the summary. // We do this by removing them from the list of events to be summarized. eventsToSummarize @@ -222,6 +219,22 @@ private Maybe> getCompactionEvents(List events) { return Maybe.just(eventsToSummarize); } + private int estimateTokenCount(List events) { + // A common rule of thumb is that one token roughly corresponds to 4 characters of text for + // common English text. + // See https://platform.openai.com/tokenizer + return events.stream().mapToInt(event -> event.stringifyContent().length()).sum() / 4; + } + + private Optional getLatestPromptTokenCount(List events) { + return Lists.reverse(events).stream() + .map(Event::usageMetadata) + .flatMap(Optional::stream) + .map(GenerateContentResponseUsageMetadata::promptTokenCount) + .flatMap(Optional::stream) + .findFirst(); + } + private static boolean isCompactEvent(Event event) { return event.actions() != null && event.actions().compaction().isPresent(); } diff --git a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java index 3260fbe1e..7a4a3ddb9 100644 --- a/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java +++ b/core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java @@ -75,9 +75,13 @@ public void constructor_negativeRetentionSize_throwsException() { } @Test - // TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is - // not available. - public void compaction_skippedWhenTokenUsageMissing() { + public void compaction_skippedWhenEstimatedTokenUsageBelowThreshold() { + // Threshold is 100. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 <= 100 -> Skip. EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100); ImmutableList events = ImmutableList.of( @@ -92,6 +96,34 @@ public void compaction_skippedWhenTokenUsageMissing() { verify(mockSessionService, never()).appendEvent(any(), any()); } + @Test + public void compaction_happensWhenEstimatedTokenUsageAboveThreshold() { + // Threshold is 2. + // Event1: "Event1" -> length 6. + // Retain1: "Retain1" -> length 7. + // Retain2: "Retain2" -> length 7. + // Total eligible for estimation (including retained ones as per current logic): + // Logic: getCompactionEvents returns [Event1, Retain1, Retain2] for estimation. + // Total length = 20. Estimated tokens = 20 / 4 = 5. + // 5 > 2 -> Compact. + EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 2); + ImmutableList events = + ImmutableList.of( + createEvent(1, "Event1"), + createEvent(2, "Retain1"), + createEvent(3, "Retain2")); // No usage metadata + Session session = Session.builder("id").events(events).build(); + Event summaryEvent = createEvent(4, "Summary"); + + when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent)); + when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent)); + + compactor.compact(session, mockSessionService).blockingSubscribe(); + + verify(mockSummarizer).summarizeEvents(any()); + verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent)); + } + @Test public void compaction_skippedWhenTokenUsageBelowThreshold() { // Threshold is 300, usage is 200.