From 5caaa84388129621114435326fcb1fa8620ab99c Mon Sep 17 00:00:00 2001 From: averyzhang Date: Sun, 21 Dec 2025 17:40:01 +0800 Subject: [PATCH 1/2] [FLINK-38825] Introduce AsyncBatchFunction and AsyncBatchWaitOperator --- .../api/datastream/AsyncDataStream.java | 49 +++ .../functions/async/AsyncBatchFunction.java | 81 +++++ .../async/AsyncBatchWaitOperator.java | 230 +++++++++++++ .../async/AsyncBatchWaitOperatorFactory.java | 66 ++++ .../async/AsyncBatchWaitOperatorTest.java | 305 ++++++++++++++++++ 5 files changed, 731 insertions(+) create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncBatchFunction.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java create mode 100644 flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java index a167d8cb17e3b..d237d0cad92e0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java @@ -20,8 +20,10 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction; import org.apache.flink.streaming.api.functions.async.AsyncFunction; import org.apache.flink.streaming.api.functions.async.AsyncRetryStrategy; +import 
org.apache.flink.streaming.api.operators.async.AsyncBatchWaitOperatorFactory; import org.apache.flink.streaming.api.operators.async.AsyncWaitOperator; import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory; import org.apache.flink.util.Preconditions; @@ -319,4 +321,51 @@ public static SingleOutputStreamOperator orderedWaitWithRetry( OutputMode.ORDERED, asyncRetryStrategy); } + + // ================================================================================ + // Batch Async Operations + // ================================================================================ + + /** + * Adds an AsyncBatchWaitOperator to process elements in batches. The order of output stream + * records may be reordered (unordered mode). + * + *

This method is particularly useful for high-latency workloads, such as machine learning + * model inference, where batching can significantly improve throughput. + * + *

The operator buffers incoming elements and triggers the async batch function when the + * buffer reaches {@code maxBatchSize}. Remaining elements are flushed when the input ends. + * + * @param in Input {@link DataStream} + * @param func {@link AsyncBatchFunction} to process batches of elements + * @param maxBatchSize Maximum number of elements to batch before triggering async invocation + * @param Type of input record + * @param Type of output record + * @return A new {@link SingleOutputStreamOperator} + */ + public static SingleOutputStreamOperator unorderedWaitBatch( + DataStream in, AsyncBatchFunction func, int maxBatchSize) { + Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0"); + + TypeInformation outTypeInfo = + TypeExtractor.getUnaryOperatorReturnType( + func, + AsyncBatchFunction.class, + 0, + 1, + new int[] {1, 0}, + in.getType(), + Utils.getCallLocationName(), + true); + + // create transform + AsyncBatchWaitOperatorFactory operatorFactory = + new AsyncBatchWaitOperatorFactory<>( + in.getExecutionEnvironment().clean(func), maxBatchSize); + + return in.transform("async batch wait operator", outTypeInfo, operatorFactory); + } + + // TODO: Add orderedWaitBatch in follow-up PR + // TODO: Add time-based batching support in follow-up PR } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncBatchFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncBatchFunction.java new file mode 100644 index 0000000000000..e8a4bf3fd0596 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/AsyncBatchFunction.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.functions.async; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.Function; + +import java.io.Serializable; +import java.util.List; + +/** + * A function to trigger Async I/O operations in batches. + * + *

For each batch of inputs, an async I/O operation can be triggered via {@link + * #asyncInvokeBatch}. Once that operation has finished, its results are emitted by calling {@link + * ResultFuture#complete}. This is particularly useful for high-latency inference workloads where + * batching can significantly improve throughput. + * + *

Unlike {@link AsyncFunction} which processes one element at a time, this interface allows + * processing multiple elements together, which is beneficial for scenarios like: + * + *

    + *
  • Machine learning model inference where batching improves GPU utilization + *
  • External service calls that support batch APIs + *
  • Database queries that can be batched for efficiency + *
+ * + *

Example usage: + * + *

{@code
+ * public class BatchInferenceFunction implements AsyncBatchFunction {
+ *
+ *   public void asyncInvokeBatch(List inputs, ResultFuture resultFuture) {
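+ *     // Submit the batch inference request. Always complete the future, even on failure:
+ *     // a future that is never completed keeps the batch in flight and blocks end of input.
+ *     CompletableFuture.supplyAsync(() -> modelService.batchInference(inputs))
+ *         .thenAccept(results -> resultFuture.complete(results))
+ *         .exceptionally(error -> { resultFuture.completeExceptionally(error); return null; });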
+ *   }
+ * }
+ * }
+ * + * @param The type of the input elements. + * @param The type of the returned elements. + */ +@PublicEvolving +public interface AsyncBatchFunction extends Function, Serializable { + + /** + * Trigger async operation for a batch of stream inputs. + * + *

The implementation should process all inputs in the batch and complete the result future + * with all corresponding outputs. The number of outputs does not need to match the number of + * inputs - it depends on the specific use case. + * + * @param inputs a batch of elements coming from upstream tasks + * @param resultFuture to be completed with the result data for the entire batch + * @throws Exception in case of a user code error. An exception will make the task fail and + * trigger fail-over process. + */ + void asyncInvokeBatch(List inputs, ResultFuture resultFuture) throws Exception; + + // TODO: Add timeout handling in follow-up PR + // TODO: Add open/close lifecycle methods in follow-up PR +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java new file mode 100644 index 0000000000000..9496bef65692d --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.api.operators.async; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.operators.MailboxExecutor; +import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction; +import org.apache.flink.streaming.api.functions.async.CollectionSupplier; +import org.apache.flink.streaming.api.functions.async.ResultFuture; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nonnull; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * The {@link AsyncBatchWaitOperator} batches incoming stream records and invokes the {@link + * AsyncBatchFunction} when the batch size reaches the configured maximum. + * + *

This operator implements unordered semantics only - results are emitted as soon as they are + * available, regardless of input order. This is suitable for AI inference workloads where order + * does not matter. + * + *

Key behaviors: + * + *

    + *
  • Buffer incoming records until batch size is reached + *
  • Flush remaining records when end of input is signaled + *
  • Emit all results from the batch function to downstream + *
+ * + *

This is a minimal implementation for the first PR. Future enhancements may include: + * + *

    + *
  • Ordered mode support + *
  • Time-based batching with timers + *
  • Timeout handling + *
  • Retry logic + *
  • Metrics + *
+ * + * @param Input type for the operator. + * @param Output type for the operator. + */ +@Internal +public class AsyncBatchWaitOperator extends AbstractStreamOperator + implements OneInputStreamOperator, BoundedOneInput { + + private static final long serialVersionUID = 1L; + + /** The async batch function to invoke. */ + private final AsyncBatchFunction asyncBatchFunction; + + /** Maximum batch size before triggering async invocation. */ + private final int maxBatchSize; + + /** Buffer for incoming stream records. */ + private transient List buffer; + + /** Mailbox executor for processing async results on the main thread. */ + private final transient MailboxExecutor mailboxExecutor; + + /** Counter for in-flight async operations. */ + private transient int inFlightCount; + + public AsyncBatchWaitOperator( + @Nonnull StreamOperatorParameters parameters, + @Nonnull AsyncBatchFunction asyncBatchFunction, + int maxBatchSize, + @Nonnull MailboxExecutor mailboxExecutor) { + Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0"); + this.asyncBatchFunction = Preconditions.checkNotNull(asyncBatchFunction); + this.maxBatchSize = maxBatchSize; + this.mailboxExecutor = Preconditions.checkNotNull(mailboxExecutor); + + // Setup the operator using parameters + setup(parameters.getContainingTask(), parameters.getStreamConfig(), parameters.getOutput()); + } + + @Override + public void open() throws Exception { + super.open(); + this.buffer = new ArrayList<>(maxBatchSize); + this.inFlightCount = 0; + } + + @Override + public void processElement(StreamRecord element) throws Exception { + buffer.add(element.getValue()); + + if (buffer.size() >= maxBatchSize) { + flushBuffer(); + } + } + + /** Flush the current buffer by invoking the async batch function. 
*/ + private void flushBuffer() throws Exception { + if (buffer.isEmpty()) { + return; + } + + // Create a copy of the buffer and clear it for new incoming elements + List batch = new ArrayList<>(buffer); + buffer.clear(); + + // Increment in-flight counter + inFlightCount++; + + // Create result handler for this batch + BatchResultHandler resultHandler = new BatchResultHandler(); + + // Invoke the async batch function + asyncBatchFunction.asyncInvokeBatch(batch, resultHandler); + } + + @Override + public void endInput() throws Exception { + // Flush any remaining elements in the buffer + flushBuffer(); + + // Wait for all in-flight async operations to complete + while (inFlightCount > 0) { + mailboxExecutor.yield(); + } + } + + @Override + public void close() throws Exception { + super.close(); + } + + /** Returns the current buffer size. Visible for testing. */ + int getBufferSize() { + return buffer != null ? buffer.size() : 0; + } + + /** A handler for the results of a batch async invocation. */ + private class BatchResultHandler implements ResultFuture { + + /** Guard against multiple completions. 
*/ + private final AtomicBoolean completed = new AtomicBoolean(false); + + @Override + public void complete(Collection results) { + Preconditions.checkNotNull( + results, "Results must not be null, use empty collection to emit nothing"); + + if (!completed.compareAndSet(false, true)) { + return; + } + + // Process results in the mailbox thread + mailboxExecutor.execute( + () -> processResults(results), "AsyncBatchWaitOperator#processResults"); + } + + @Override + public void completeExceptionally(Throwable error) { + if (!completed.compareAndSet(false, true)) { + return; + } + + // Signal failure through the containing task + getContainingTask() + .getEnvironment() + .failExternally(new Exception("Async batch operation failed.", error)); + + // Decrement in-flight counter in mailbox thread + mailboxExecutor.execute( + () -> inFlightCount--, "AsyncBatchWaitOperator#decrementInFlight"); + } + + @Override + public void complete(CollectionSupplier supplier) { + Preconditions.checkNotNull( + supplier, "Supplier must not be null, return empty collection to emit nothing"); + + if (!completed.compareAndSet(false, true)) { + return; + } + + mailboxExecutor.execute( + () -> { + try { + processResults(supplier.get()); + } catch (Throwable t) { + getContainingTask() + .getEnvironment() + .failExternally( + new Exception("Async batch operation failed.", t)); + inFlightCount--; + } + }, + "AsyncBatchWaitOperator#processResultsFromSupplier"); + } + + private void processResults(Collection results) { + // Emit all results downstream + for (OUT result : results) { + output.collect(new StreamRecord<>(result)); + } + // Decrement in-flight counter + inFlightCount--; + } + } +} diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java new file mode 100644 index 0000000000000..2245ca7945b84 --- 
/dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.async; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.api.operators.legacy.YieldingOperatorFactory; + +/** + * The factory of {@link AsyncBatchWaitOperator}. 
+ * + * @param The input type of the operator + * @param The output type of the operator + */ +@Internal +public class AsyncBatchWaitOperatorFactory extends AbstractStreamOperatorFactory + implements OneInputStreamOperatorFactory, YieldingOperatorFactory { + + private static final long serialVersionUID = 1L; + + private final AsyncBatchFunction asyncBatchFunction; + private final int maxBatchSize; + + public AsyncBatchWaitOperatorFactory( + AsyncBatchFunction asyncBatchFunction, int maxBatchSize) { + this.asyncBatchFunction = asyncBatchFunction; + this.maxBatchSize = maxBatchSize; + this.chainingStrategy = ChainingStrategy.ALWAYS; + } + + @Override + @SuppressWarnings("unchecked") + public > T createStreamOperator( + StreamOperatorParameters parameters) { + AsyncBatchWaitOperator operator = + new AsyncBatchWaitOperator<>( + parameters, asyncBatchFunction, maxBatchSize, getMailboxExecutor()); + return (T) operator; + } + + @Override + public Class getStreamOperatorClass(ClassLoader classLoader) { + return AsyncBatchWaitOperator.class; + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java new file mode 100644 index 0000000000000..5b94045f6d9be --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.async; + +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.runtime.operators.testutils.ExpectedTestException; +import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AsyncBatchWaitOperator}. + * + *

These tests verify: + * + *

    + *
  • Batch size trigger - elements are batched correctly + *
  • Correct result emission - all outputs are emitted downstream + *
  • Exception propagation - errors fail the operator + *
+ */ +@Timeout(value = 100, unit = TimeUnit.SECONDS) +class AsyncBatchWaitOperatorTest { + + /** + * Test that the operator correctly batches elements based on maxBatchSize. + * + *

Input: 5 records with maxBatchSize = 3 + * + *

Expected: 2 batch invocations with sizes [3, 2] + */ + @Test + void testBatchSizeTrigger() throws Exception { + final int maxBatchSize = 3; + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + // Return input * 2 for each element + List results = + inputs.stream().map(i -> i * 2).collect(Collectors.toList()); + resultFuture.complete(results); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + // Process 5 elements + testHarness.processElement(new StreamRecord<>(1, 1L)); + testHarness.processElement(new StreamRecord<>(2, 2L)); + testHarness.processElement(new StreamRecord<>(3, 3L)); + // First batch of 3 should be triggered here + + testHarness.processElement(new StreamRecord<>(4, 4L)); + testHarness.processElement(new StreamRecord<>(5, 5L)); + // Remaining 2 elements in buffer + + testHarness.endInput(); + // Second batch of 2 should be triggered on endInput + + // Verify batch sizes + assertThat(batchSizes).containsExactly(3, 2); + } + } + + /** Test that all results from the batch function are correctly emitted downstream. 
*/ + @Test + void testCorrectResultEmission() throws Exception { + final int maxBatchSize = 3; + + // Function that doubles each input + AsyncBatchFunction function = + (inputs, resultFuture) -> { + List results = + inputs.stream().map(i -> i * 2).collect(Collectors.toList()); + resultFuture.complete(results); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + // Process 5 elements: 1, 2, 3, 4, 5 + for (int i = 1; i <= 5; i++) { + testHarness.processElement(new StreamRecord<>(i, i)); + } + + testHarness.endInput(); + + // Verify outputs: should be 2, 4, 6, 8, 10 + List outputs = + testHarness.getOutput().stream() + .filter(e -> e instanceof StreamRecord) + .map(e -> ((StreamRecord) e).getValue()) + .collect(Collectors.toList()); + + assertThat(outputs).containsExactlyInAnyOrder(2, 4, 6, 8, 10); + } + } + + /** Test that exceptions from the batch function are properly propagated. */ + @Test + void testExceptionPropagation() throws Exception { + final int maxBatchSize = 2; + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + resultFuture.completeExceptionally(new ExpectedTestException()); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + // Process 2 elements to trigger a batch + testHarness.processElement(new StreamRecord<>(1, 1L)); + testHarness.processElement(new StreamRecord<>(2, 2L)); + + // The exception should be propagated - we need to yield to process the async result + // In the test harness, the exception is recorded in the environment + testHarness.endInput(); + + // Verify that the task environment received the exception + assertThat(testHarness.getEnvironment().getActualExternalFailureCause()) + .isPresent() + .get() + .satisfies( + t -> + assertThat(t.getCause()) + .isInstanceOf(ExpectedTestException.class)); + } + } + + /** Test async completion using 
CompletableFuture. */ + @Test + void testAsyncCompletion() throws Exception { + final int maxBatchSize = 2; + final AtomicInteger invocationCount = new AtomicInteger(0); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + invocationCount.incrementAndGet(); + // Simulate async processing + CompletableFuture.supplyAsync( + () -> + inputs.stream() + .map(i -> i * 3) + .collect(Collectors.toList())) + .thenAccept(resultFuture::complete); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + // Process 4 elements: should trigger 2 batches + for (int i = 1; i <= 4; i++) { + testHarness.processElement(new StreamRecord<>(i, i)); + } + + testHarness.endInput(); + + // Verify invocation count + assertThat(invocationCount.get()).isEqualTo(2); + + // Verify outputs: should be 3, 6, 9, 12 + List outputs = + testHarness.getOutput().stream() + .filter(e -> e instanceof StreamRecord) + .map(e -> ((StreamRecord) e).getValue()) + .collect(Collectors.toList()); + + assertThat(outputs).containsExactlyInAnyOrder(3, 6, 9, 12); + } + } + + /** Test that empty batches are not triggered. */ + @Test + void testEmptyInput() throws Exception { + final int maxBatchSize = 3; + final AtomicInteger invocationCount = new AtomicInteger(0); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + invocationCount.incrementAndGet(); + resultFuture.complete(Collections.emptyList()); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + testHarness.endInput(); + + // No invocations should happen for empty input + assertThat(invocationCount.get()).isEqualTo(0); + assertThat(testHarness.getOutput()).isEmpty(); + } + } + + /** Test that batch function can return fewer or more outputs than inputs. 
*/ + @Test + void testVariableOutputSize() throws Exception { + final int maxBatchSize = 3; + + // Function that returns only one output per batch (aggregation-style) + AsyncBatchFunction function = + (inputs, resultFuture) -> { + int sum = inputs.stream().mapToInt(Integer::intValue).sum(); + resultFuture.complete(Collections.singletonList(sum)); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + // Process 5 elements: 1, 2, 3, 4, 5 + for (int i = 1; i <= 5; i++) { + testHarness.processElement(new StreamRecord<>(i, i)); + } + + testHarness.endInput(); + + // First batch: 1+2+3 = 6, Second batch: 4+5 = 9 + List outputs = + testHarness.getOutput().stream() + .filter(e -> e instanceof StreamRecord) + .map(e -> ((StreamRecord) e).getValue()) + .collect(Collectors.toList()); + + assertThat(outputs).containsExactlyInAnyOrder(6, 9); + } + } + + /** Test single element batch (maxBatchSize = 1). */ + @Test + void testSingleElementBatch() throws Exception { + final int maxBatchSize = 1; + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarness(function, maxBatchSize)) { + + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(1, 1L)); + testHarness.processElement(new StreamRecord<>(2, 2L)); + testHarness.processElement(new StreamRecord<>(3, 3L)); + + testHarness.endInput(); + + // Each element should trigger its own batch + assertThat(batchSizes).containsExactly(1, 1, 1); + } + } + + private static OneInputStreamOperatorTestHarness createTestHarness( + AsyncBatchFunction function, int maxBatchSize) throws Exception { + return new OneInputStreamOperatorTestHarness<>( + new AsyncBatchWaitOperatorFactory<>(function, maxBatchSize), + IntSerializer.INSTANCE); + } +} From 
35e5398cbaa74eb3c1f627ebc862b86a69756600 Mon Sep 17 00:00:00 2001 From: averyzhang Date: Sun, 21 Dec 2025 18:16:22 +0800 Subject: [PATCH 2/2] [FLINK-38825] Add time-based batch triggering for AsyncBatchWaitOperator --- .../api/datastream/AsyncDataStream.java | 37 ++- .../async/AsyncBatchWaitOperator.java | 128 +++++++++- .../async/AsyncBatchWaitOperatorFactory.java | 29 ++- .../async/AsyncBatchWaitOperatorTest.java | 220 ++++++++++++++++++ 4 files changed, 405 insertions(+), 9 deletions(-) diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java index d237d0cad92e0..d0fdefd47d4d8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/AsyncDataStream.java @@ -345,6 +345,39 @@ public static SingleOutputStreamOperator orderedWaitWithRetry( */ public static SingleOutputStreamOperator unorderedWaitBatch( DataStream in, AsyncBatchFunction func, int maxBatchSize) { + return unorderedWaitBatch(in, func, maxBatchSize, 0L); + } + + /** + * Adds an AsyncBatchWaitOperator to process elements in batches with timeout support. The order + * of output stream records may be reordered (unordered mode). + * + *

This method is particularly useful for high-latency workloads, such as machine learning + * model inference, where batching can significantly improve throughput. + * + *

The operator buffers incoming elements and triggers the async batch function when either: + * + *

    + *
  • The buffer reaches {@code maxBatchSize} + *
  • The {@code batchTimeoutMs} has elapsed since the first buffered element (if timeout is + * enabled) + *
+ * + *

Remaining elements are flushed when the input ends. + * + * @param in Input {@link DataStream} + * @param func {@link AsyncBatchFunction} to process batches of elements + * @param maxBatchSize Maximum number of elements to batch before triggering async invocation + * @param batchTimeoutMs Batch timeout in milliseconds; <= 0 means timeout is disabled + * @param Type of input record + * @param Type of output record + * @return A new {@link SingleOutputStreamOperator} + */ + public static SingleOutputStreamOperator unorderedWaitBatch( + DataStream in, + AsyncBatchFunction func, + int maxBatchSize, + long batchTimeoutMs) { Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0"); TypeInformation outTypeInfo = @@ -361,11 +394,11 @@ public static SingleOutputStreamOperator unorderedWaitBatch( // create transform AsyncBatchWaitOperatorFactory operatorFactory = new AsyncBatchWaitOperatorFactory<>( - in.getExecutionEnvironment().clean(func), maxBatchSize); + in.getExecutionEnvironment().clean(func), maxBatchSize, batchTimeoutMs); return in.transform("async batch wait operator", outTypeInfo, operatorFactory); } // TODO: Add orderedWaitBatch in follow-up PR - // TODO: Add time-based batching support in follow-up PR + // TODO: Add event-time based batching support in follow-up PR } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java index 9496bef65692d..1a05b6055d308 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperator.java @@ -20,6 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.operators.MailboxExecutor; +import 
org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; import org.apache.flink.streaming.api.functions.async.AsyncBatchFunction; import org.apache.flink.streaming.api.functions.async.CollectionSupplier; import org.apache.flink.streaming.api.functions.async.ResultFuture; @@ -39,7 +40,8 @@ /** * The {@link AsyncBatchWaitOperator} batches incoming stream records and invokes the {@link - * AsyncBatchFunction} when the batch size reaches the configured maximum. + * AsyncBatchFunction} when the batch size reaches the configured maximum or when the batch timeout + * is reached. * *

This operator implements unordered semantics only - results are emitted as soon as they are * available, regardless of input order. This is suitable for AI inference workloads where order @@ -48,17 +50,26 @@ *

Key behaviors: * *

    - *
  • Buffer incoming records until batch size is reached + *
  • Buffer incoming records until batch size is reached OR timeout expires *
  • Flush remaining records when end of input is signaled *
  • Emit all results from the batch function to downstream *
* - *

This is a minimal implementation for the first PR. Future enhancements may include: + *

Timer lifecycle (when batchTimeoutMs > 0): + * + *

    + *
  • Timer is registered when first element is added to an empty buffer + *
  • Timer fires at: currentBatchStartTime + batchTimeoutMs + *
  • Timer state is reset on flush (size, timeout, end-of-input); the timer is not cancelled + *
  • At most one timer is tracked at any time; timers of flushed batches may still fire + *
+ * + *

Future enhancements may include: * *

    *
  • Ordered mode support - *
  • Time-based batching with timers - *
  • Timeout handling + *
  • Event-time based batching + *
  • Multiple inflight batches *
  • Retry logic *
  • Metrics *
@@ -68,16 +79,25 @@ */ @Internal public class AsyncBatchWaitOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput { + implements OneInputStreamOperator, BoundedOneInput, ProcessingTimeCallback { private static final long serialVersionUID = 1L; + /** Constant indicating timeout is disabled. */ + private static final long NO_TIMEOUT = 0L; + /** The async batch function to invoke. */ private final AsyncBatchFunction asyncBatchFunction; /** Maximum batch size before triggering async invocation. */ private final int maxBatchSize; + /** + * Batch timeout in milliseconds. When positive, a timer is registered to flush the batch after + * this duration since the first buffered element. A value <= 0 disables timeout-based batching. + */ + private final long batchTimeoutMs; + /** Buffer for incoming stream records. */ private transient List buffer; @@ -87,14 +107,54 @@ public class AsyncBatchWaitOperator extends AbstractStreamOperator /** Counter for in-flight async operations. */ private transient int inFlightCount; + // ================================================================================ + // Timer state fields for timeout-based batching + // ================================================================================ + + /** + * The processing time when the current batch started (i.e., when first element was added to + * empty buffer). Used to calculate timer fire time. + */ + private transient long currentBatchStartTime; + + /** Whether a timer is currently registered for the current batch. */ + private transient boolean timerRegistered; + + /** + * Creates an AsyncBatchWaitOperator with size-based batching only (no timeout). 
+ * + * @param parameters Stream operator parameters + * @param asyncBatchFunction The async batch function to invoke + * @param maxBatchSize Maximum batch size before triggering async invocation + * @param mailboxExecutor Mailbox executor for processing async results + */ public AsyncBatchWaitOperator( @Nonnull StreamOperatorParameters parameters, @Nonnull AsyncBatchFunction asyncBatchFunction, int maxBatchSize, @Nonnull MailboxExecutor mailboxExecutor) { + this(parameters, asyncBatchFunction, maxBatchSize, NO_TIMEOUT, mailboxExecutor); + } + + /** + * Creates an AsyncBatchWaitOperator with size-based and optional timeout-based batching. + * + * @param parameters Stream operator parameters + * @param asyncBatchFunction The async batch function to invoke + * @param maxBatchSize Maximum batch size before triggering async invocation + * @param batchTimeoutMs Batch timeout in milliseconds; <= 0 means disabled + * @param mailboxExecutor Mailbox executor for processing async results + */ + public AsyncBatchWaitOperator( + @Nonnull StreamOperatorParameters parameters, + @Nonnull AsyncBatchFunction asyncBatchFunction, + int maxBatchSize, + long batchTimeoutMs, + @Nonnull MailboxExecutor mailboxExecutor) { Preconditions.checkArgument(maxBatchSize > 0, "maxBatchSize must be greater than 0"); this.asyncBatchFunction = Preconditions.checkNotNull(asyncBatchFunction); this.maxBatchSize = maxBatchSize; + this.batchTimeoutMs = batchTimeoutMs; this.mailboxExecutor = Preconditions.checkNotNull(mailboxExecutor); // Setup the operator using parameters @@ -106,23 +166,51 @@ public void open() throws Exception { super.open(); this.buffer = new ArrayList<>(maxBatchSize); this.inFlightCount = 0; + this.currentBatchStartTime = 0L; + this.timerRegistered = false; } @Override public void processElement(StreamRecord element) throws Exception { + // If buffer is empty and timeout is enabled, record batch start time and register timer + if (buffer.isEmpty() && isTimeoutEnabled()) { + 
currentBatchStartTime = getProcessingTimeService().getCurrentProcessingTime(); + registerBatchTimer(); + } + buffer.add(element.getValue()); + // Size-triggered flush; flushBuffer() also resets the timer bookkeeping if (buffer.size() >= maxBatchSize) { flushBuffer(); } } + /** + * Callback for processing time timers. Flushes the current batch; stale timers are ignored. + * + * @param timestamp The timestamp for which the timer was registered + */ + @Override + public void onProcessingTime(long timestamp) throws Exception { + // Timers are never cancelled, so a timer registered for a batch that was already + // flushed (by size or end-of-input) can still fire. Only the timer that belongs to + // the current batch may flush; otherwise the next batch would be cut short. + if (timerRegistered && timestamp == currentBatchStartTime + batchTimeoutMs) { + timerRegistered = false; + flushBuffer(); + } + } + /** Flush the current buffer by invoking the async batch function. */ private void flushBuffer() throws Exception { if (buffer.isEmpty()) { return; } + // Clear timer state since we're flushing the batch + clearTimerState(); + // Create a copy of the buffer and clear it for new incoming elements List batch = new ArrayList<>(buffer); buffer.clear(); @@ -153,6 +241,34 @@ public void close() throws Exception { super.close(); } + // ================================================================================ + // Timer management methods + // ================================================================================ + + /** Check if timeout-based batching is enabled. */ + private boolean isTimeoutEnabled() { + return batchTimeoutMs > NO_TIMEOUT; + } + + /** Register a processing time timer for the current batch. */ + private void registerBatchTimer() { + if (!timerRegistered && isTimeoutEnabled()) { + long fireTime = currentBatchStartTime + batchTimeoutMs; + getProcessingTimeService().registerTimer(fireTime, this); + timerRegistered = true; + } + } + + /** + * Clear timer state. Note: We don't explicitly cancel the timer because: 1. The timer callback + * checks buffer state before flushing 2. Cancelling timers has overhead 3. 
Timer will be + * ignored if buffer is empty when it fires + */ + private void clearTimerState() { + timerRegistered = false; + currentBatchStartTime = 0L; + } + /** Returns the current buffer size. Visible for testing. */ int getBufferSize() { return buffer != null ? buffer.size() : 0; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java index 2245ca7945b84..02839380b0c72 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorFactory.java @@ -39,13 +39,36 @@ public class AsyncBatchWaitOperatorFactory extends AbstractStreamOperat private static final long serialVersionUID = 1L; + /** Constant indicating timeout is disabled. */ + private static final long NO_TIMEOUT = 0L; + private final AsyncBatchFunction asyncBatchFunction; private final int maxBatchSize; + private final long batchTimeoutMs; + /** + * Creates a factory with size-based batching only (no timeout). + * + * @param asyncBatchFunction The async batch function + * @param maxBatchSize Maximum batch size before triggering async invocation + */ public AsyncBatchWaitOperatorFactory( AsyncBatchFunction asyncBatchFunction, int maxBatchSize) { + this(asyncBatchFunction, maxBatchSize, NO_TIMEOUT); + } + + /** + * Creates a factory with size-based and optional timeout-based batching. 
+ * + * @param asyncBatchFunction The async batch function + * @param maxBatchSize Maximum batch size before triggering async invocation + * @param batchTimeoutMs Batch timeout in milliseconds; <= 0 means disabled + */ + public AsyncBatchWaitOperatorFactory( + AsyncBatchFunction asyncBatchFunction, int maxBatchSize, long batchTimeoutMs) { this.asyncBatchFunction = asyncBatchFunction; this.maxBatchSize = maxBatchSize; + this.batchTimeoutMs = batchTimeoutMs; this.chainingStrategy = ChainingStrategy.ALWAYS; } @@ -55,7 +78,11 @@ public > T createStreamOperator( StreamOperatorParameters parameters) { AsyncBatchWaitOperator operator = new AsyncBatchWaitOperator<>( - parameters, asyncBatchFunction, maxBatchSize, getMailboxExecutor()); + parameters, + asyncBatchFunction, + maxBatchSize, + batchTimeoutMs, + getMailboxExecutor()); return (T) operator; } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java index 5b94045f6d9be..f2141759251d1 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncBatchWaitOperatorTest.java @@ -296,10 +296,230 @@ void testSingleElementBatch() throws Exception { } } + // ================================================================================ + // Timeout-based batching tests + // ================================================================================ + + /** + * Test that timeout triggers batch flush even when batch size is not reached. + * + *

maxBatchSize = 10, batchTimeoutMs = 50 + * + *

Send 1 record, advance processing time, expect asyncInvokeBatch called with size 1 + */ + @Test + void testTimeoutFlush() throws Exception { + final int maxBatchSize = 10; + final long batchTimeoutMs = 50L; + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarnessWithTimeout(function, maxBatchSize, batchTimeoutMs)) { + + testHarness.open(); + + // Set initial processing time + testHarness.setProcessingTime(0L); + + // Process 1 element - should start the timer + testHarness.processElement(new StreamRecord<>(1, 1L)); + + // Batch size not reached, no flush yet + assertThat(batchSizes).isEmpty(); + + // Advance processing time past timeout threshold + testHarness.setProcessingTime(batchTimeoutMs + 1); + + // Timer should have fired, triggering batch flush with size 1 + assertThat(batchSizes).containsExactly(1); + + testHarness.endInput(); + } + } + + /** + * Test that size-triggered flush happens before timeout when batch fills up quickly. + * + *

maxBatchSize = 2, batchTimeoutMs = 1 hour (3600000 ms) + * + *

Send 2 records immediately, verify batch is flushed by size, not by timeout + */ + @Test + void testSizeBeatsTimeout() throws Exception { + final int maxBatchSize = 2; + final long batchTimeoutMs = 3600000L; // 1 hour - should never be reached + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarnessWithTimeout(function, maxBatchSize, batchTimeoutMs)) { + + testHarness.open(); + + // Set initial processing time + testHarness.setProcessingTime(0L); + + // Process 2 elements immediately - should trigger batch by size + testHarness.processElement(new StreamRecord<>(1, 1L)); + testHarness.processElement(new StreamRecord<>(2, 2L)); + + // Batch should have been flushed by size (not timeout) + assertThat(batchSizes).containsExactly(2); + + // Even if we advance time, no additional flush should happen since buffer is empty + testHarness.setProcessingTime(batchTimeoutMs + 1); + assertThat(batchSizes).containsExactly(2); + + testHarness.endInput(); + } + } + + /** + * Test that timer is properly reset after batch flush. + * + *

First batch flushed by timeout, second batch starts a new timer. + */ + @Test + void testTimerResetAfterFlush() throws Exception { + final int maxBatchSize = 10; + final long batchTimeoutMs = 100L; + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarnessWithTimeout(function, maxBatchSize, batchTimeoutMs)) { + + testHarness.open(); + + // === First batch === + testHarness.setProcessingTime(0L); + testHarness.processElement(new StreamRecord<>(1, 1L)); + + // Advance time to trigger first timeout flush + testHarness.setProcessingTime(batchTimeoutMs + 1); + assertThat(batchSizes).containsExactly(1); + + // === Second batch === + // Start second batch at time 200 + testHarness.setProcessingTime(200L); + testHarness.processElement(new StreamRecord<>(2, 2L)); + testHarness.processElement(new StreamRecord<>(3, 3L)); + + // No flush yet - batch size not reached + assertThat(batchSizes).containsExactly(1); + + // Advance time to trigger second timeout flush (200 + 100 + 1 = 301) + testHarness.setProcessingTime(301L); + assertThat(batchSizes).containsExactly(1, 2); + + testHarness.endInput(); + } + } + + /** Test timeout with multiple batches interleaving size and timeout triggers. 
*/ + @Test + void testMixedSizeAndTimeoutTriggers() throws Exception { + final int maxBatchSize = 3; + final long batchTimeoutMs = 100L; + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarnessWithTimeout(function, maxBatchSize, batchTimeoutMs)) { + + testHarness.open(); + testHarness.setProcessingTime(0L); + + // First batch: size-triggered + testHarness.processElement(new StreamRecord<>(1, 1L)); + testHarness.processElement(new StreamRecord<>(2, 2L)); + testHarness.processElement(new StreamRecord<>(3, 3L)); + assertThat(batchSizes).containsExactly(3); + + // Second batch: timeout-triggered + testHarness.setProcessingTime(200L); + testHarness.processElement(new StreamRecord<>(4, 4L)); + assertThat(batchSizes).containsExactly(3); // Not flushed yet + + testHarness.setProcessingTime(301L); // 200 + 100 + 1 + assertThat(batchSizes).containsExactly(3, 1); + + // Third batch: size-triggered again + testHarness.setProcessingTime(400L); + testHarness.processElement(new StreamRecord<>(5, 5L)); + testHarness.processElement(new StreamRecord<>(6, 6L)); + testHarness.processElement(new StreamRecord<>(7, 7L)); + assertThat(batchSizes).containsExactly(3, 1, 3); + + testHarness.endInput(); + } + } + + /** Test that timeout is disabled when batchTimeoutMs <= 0. 
*/ + @Test + void testTimeoutDisabled() throws Exception { + final int maxBatchSize = 10; + final long batchTimeoutMs = 0L; // Disabled + final List batchSizes = new CopyOnWriteArrayList<>(); + + AsyncBatchFunction function = + (inputs, resultFuture) -> { + batchSizes.add(inputs.size()); + resultFuture.complete(inputs); + }; + + try (OneInputStreamOperatorTestHarness testHarness = + createTestHarnessWithTimeout(function, maxBatchSize, batchTimeoutMs)) { + + testHarness.open(); + testHarness.setProcessingTime(0L); + + // Process 1 element + testHarness.processElement(new StreamRecord<>(1, 1L)); + + // Advance time significantly - should not trigger flush since timeout is disabled + testHarness.setProcessingTime(1000000L); + assertThat(batchSizes).isEmpty(); + + // Flush happens only on endInput + testHarness.endInput(); + assertThat(batchSizes).containsExactly(1); + } + } + private static OneInputStreamOperatorTestHarness createTestHarness( AsyncBatchFunction function, int maxBatchSize) throws Exception { return new OneInputStreamOperatorTestHarness<>( new AsyncBatchWaitOperatorFactory<>(function, maxBatchSize), IntSerializer.INSTANCE); } + + private static OneInputStreamOperatorTestHarness createTestHarnessWithTimeout( + AsyncBatchFunction function, int maxBatchSize, long batchTimeoutMs) + throws Exception { + return new OneInputStreamOperatorTestHarness<>( + new AsyncBatchWaitOperatorFactory<>(function, maxBatchSize, batchTimeoutMs), + IntSerializer.INSTANCE); + } }