From 034902be5b6c062c94342355faad6d874d18ba13 Mon Sep 17 00:00:00 2001
From: Liu Zhengyun
Date: Fri, 16 Jan 2026 12:14:19 +0800
Subject: [PATCH] fix bugs and sync code

---
 .../ainode/iotdb/ainode/core/constant.py      |  2 +-
 .../core/inference/batcher/basic_batcher.py   | 27 ++++++++++++++-----
 .../core/inference/inference_request.py       | 15 +++++------
 .../core/inference/inference_request_pool.py  |  2 +-
 .../pool_scheduler/basic_pool_scheduler.py    | 11 +++++++-
 .../ainode/core/manager/inference_manager.py  |  2 +-
 .../ainode/iotdb/ainode/core/manager/utils.py |  4 +--
 .../function/tvf/ClassifyTableFunction.java   |  2 +-
 .../function/tvf/ForecastTableFunction.java   |  2 +-
 .../db/queryengine/plan/udf/UDTFForecast.java |  3 +--
 .../src/main/thrift/ainode.thrift             |  2 +-
 11 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index 8a83c98143795..68f64a79afc45 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -53,7 +53,7 @@
 
 # TODO: Should be optimized
 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
     "sundial": 1036 * 1024**2,  # 1036 MiB
-    "timer": 856 * 1024**2,  # 856 MiB
+    "timer_xl": 856 * 1024**2,  # 856 MiB
 }  # the memory usage of each model in bytes
 AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2  # the device space allocated for inference
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
index 591a0d7c1dd0b..80ff683d9d18d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
@@ -34,6 +34,7 @@ def __init__(self):
 
         Args:
         """
+        super().__init__()
 
     def batch_request(self, reqs: List[InferenceRequest]) -> torch.Tensor:
         """
@@ -46,17 +47,29 @@
         Returns:
             torch.Tensor: Concatenated input tensor of shape
-                [sum(req.batch_size), length].
+                [sum(req.batch_size), target_count, input_length].
         """
         if not reqs:
             raise ValueError("No requests provided to batch_request.")
 
-        # Ensure length consistency
-        length_set = {req.inputs.shape[1] for req in reqs}
-        if len(length_set) != 1:
-            raise ValueError(
-                f"All requests must have the same length, " f"but got {length_set}"
-            )
+        # Ensure shape consistency
+        first_target_count = reqs[0].target_count
+        first_input_length = reqs[0].input_length
+
+        for i, req in enumerate(reqs):
+            if req.target_count != first_target_count:
+                raise ValueError(
+                    f"All requests must have the same target_count, "
+                    f"but request 0 has {first_target_count} "
+                    f"and request {i} has {req.target_count}"
+                )
+
+            if req.input_length != first_input_length:
+                raise ValueError(
+                    f"All requests must have the same input_length, "
+                    f"but request 0 has {first_input_length} "
+                    f"and request {i} has {req.input_length}"
+                )
 
         batch_inputs = torch.cat([req.inputs for req in reqs], dim=0)
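For illustration, a minimal sketch of the shape invariant `batch_request` now enforces; `_Req` and the concrete sizes are hypothetical stand-ins for `InferenceRequest` and real workloads:

```python
import torch

# _Req is a hypothetical stand-in for InferenceRequest, reduced to the
# fields batch_request() reads.
class _Req:
    def __init__(self, inputs: torch.Tensor):
        self.inputs = inputs  # shape: [batch_size, target_count, input_length]
        self.target_count = inputs.size(1)
        self.input_length = inputs.size(2)

# Requests with matching (target_count, input_length) concatenate along dim 0:
a = _Req(torch.zeros(2, 3, 96))
b = _Req(torch.zeros(5, 3, 96))
batched = torch.cat([a.inputs, b.inputs], dim=0)
assert batched.shape == (7, 3, 96)  # batch sizes sum; the tail dims are shared
```

The old check compared only `inputs.shape[1]`, which on the 3-D layout is `target_count`, not the length; the per-request checks cover both tail dimensions and name the offending request index rather than reporting a bare set of lengths.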
""" if not reqs: raise ValueError("No requests provided to batch_request.") - # Ensure length consistency - length_set = {req.inputs.shape[1] for req in reqs} - if len(length_set) != 1: - raise ValueError( - f"All requests must have the same length, " f"but got {length_set}" - ) + # Ensure shape consistency + first_target_count = reqs[0].target_count + first_input_length = reqs[0].input_length + + for i, req in enumerate(reqs): + if req.target_count != first_target_count: + raise ValueError( + f"All requests must have the same target_count, " + f"but request 0 has {first_target_count} " + f"and request {i} has {req.target_count}" + ) + + if req.input_length != first_input_length: + raise ValueError( + f"All requests must have the same input_length, " + f"but request 0 has {first_input_length} " + f"and request {i} has {req.input_length}" + ) batch_inputs = torch.cat([req.inputs for req in reqs], dim=0) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 93887477aa55f..43380aa1a08d1 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -49,21 +49,20 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.output_length = ( - output_length # Number of time series data points to generate - ) + self.output_length = output_length self.batch_size = inputs.size(0) - self.variable_size = inputs.size(1) + self.target_count = inputs.size(1) + self.input_length = inputs.size(2) self.state = InferenceRequestState.WAITING self.cur_step_idx = 0 # Current write position in the output step index self.assigned_pool_id = -1 # The pool handling this request self.assigned_device_id = -1 # The device handling this request - # Preallocate output buffer [batch_size, max_new_tokens] + # Preallocate output buffer [batch_size, target_count, output_length] self.output_tensor = torch.zeros( - self.batch_size, self.variable_size, output_length, device="cpu" - ) # shape: [batch_size, target_count, predict_length] + self.batch_size, self.target_count, output_length, device="cpu" + ) def mark_running(self): self.state = InferenceRequestState.RUNNING @@ -81,7 +80,7 @@ def write_step_output(self, step_output: torch.Tensor): while step_output.ndim < 3: step_output = step_output.unsqueeze(0) - batch_size, variable_size, step_size = step_output.shape + batch_size, target_count, step_size = step_output.shape end_idx = self.cur_step_idx + step_size if end_idx > self.output_length: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 516c1d07c2c79..dcfa4528fce97 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -115,7 +115,7 @@ def _step(self): grouped_requests = defaultdict(list) for req in all_requests: - key = (req.inputs.shape[1], req.output_length) + key = (req.target_count, req.input_length, req.output_length) grouped_requests[key].append(req) grouped_requests = list(grouped_requests.values()) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 65aa77143939a..591577785fdd5 100644 --- 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index addcfad6cfb5f..ebbb036a9dca7 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -253,7 +253,7 @@ def _run(
 
         return resp_cls(
             get_status(TSStatusCode.SUCCESS_STATUS),
-            output_list[0] if single_batch else output_list,
+            [output_list[0]] if single_batch else output_list,
         )
 
     except Exception as e:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 45db66e018674..892d4650e152d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -67,9 +67,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
     system_res = evaluate_system_resources(device)
     free_mem = system_res["free_mem"]
 
-    mem_usage = (
-        MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
-    )
+    mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
     size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
     if size <= 0:
         logger.error(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
index 670e019a4b610..34a1a6b223981 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
@@ -377,7 +377,7 @@ private TsBlock classify() {
                 "Error occurred while executing classify:[%s]", resp.getStatus().getMessage());
         throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
       }
-      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+      return SERDE.deserialize(resp.forecastResult.get(0));
     }
   }
 }
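For illustration, a minimal sketch of the single-batch wrapping fixed in `inference_manager.py` above; the byte strings are placeholders, not real serialized TsBlocks:

```python
output_list = [b"tsblock-0"]  # placeholder payload, not a real serialized TsBlock
single_batch = True

# The old code sent output_list[0] (bare bytes) when single_batch was true;
# keeping the field a one-element list matches the new list-typed thrift
# field, so Java callers can uniformly do resp.forecastResult.get(0).
payload = [output_list[0]] if single_batch else output_list
assert payload == [b"tsblock-0"]
```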
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index dcb27825e3148..579802c542a41 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -574,7 +574,7 @@ protected TsBlock forecast() {
         throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
       }
 
-      TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+      TsBlock res = SERDE.deserialize(resp.forecastResult.get(0));
       if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
         throw new IoTDBRuntimeException(
             String.format(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index ebecf79f5b70b..a6794a5896fb8 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -40,7 +40,6 @@
 import org.apache.tsfile.read.common.block.column.TsBlockSerde;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
@@ -226,7 +225,7 @@ private TsBlock forecast() throws Exception {
               resp.getStatus().getCode(), resp.getStatus().getMessage()),
           resp.getStatus().getCode());
     }
-    return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    return serde.deserialize(resp.forecastResult.get(0));
   }
 
   @Override
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1cb585f0323cd..68347b89203ca 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -92,7 +92,7 @@ struct TForecastReq {
 
 struct TForecastResp {
   1: required common.TSStatus status
-  2: optional binary forecastResult
+  2: optional list<binary> forecastResult
 }
 
 struct TShowModelsReq {
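For illustration, a minimal sketch of the new response shape; `TForecastResp` here is a hypothetical Python stand-in for the thrift-generated class, assuming thrift `binary` maps to `bytes`:

```python
from typing import List, Optional

class TForecastResp:  # hypothetical stand-in; the real class is thrift-generated
    def __init__(self, status: int, forecastResult: Optional[List[bytes]] = None):
        self.status = status
        self.forecastResult = forecastResult  # was a single binary before this patch

resp = TForecastResp(status=0, forecastResult=[b"block-0", b"block-1"])
first = resp.forecastResult[0]  # mirrors resp.forecastResult.get(0) in the Java callers
assert first == b"block-0"
```

In Java, `list<binary>` generates a `List<ByteBuffer>` field, which is why the callers above drop `ByteBuffer.wrap(...)` and pass `resp.forecastResult.get(0)` straight to the TsBlock deserializer.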