2 changes: 1 addition & 1 deletion iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -53,7 +53,7 @@
# TODO: Should be optimized
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
"sundial": 1036 * 1024**2, # 1036 MiB
"timer": 856 * 1024**2, # 856 MiB
"timer_xl": 856 * 1024**2, # 856 MiB
} # the memory usage of each model in bytes

AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference
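Note: the memory map is keyed by model_id, so renaming the key to "timer_xl" only takes effect if every lookup uses the same identifier. A minimal sketch of how such a lookup might be guarded; the helper name `supports_concurrent_forecasting` is illustrative and not part of this patch:

```python
# Illustrative only: values mirror the constants above; the helper is hypothetical.
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
    "sundial": 1036 * 1024**2,  # 1036 MiB
    "timer_xl": 856 * 1024**2,  # 856 MiB
}

def supports_concurrent_forecasting(model_id: str) -> bool:
    # A model can only be placed in a shared inference pool if its
    # memory footprint is listed in the map.
    return model_id in AINODE_INFERENCE_MODEL_MEM_USAGE_MAP

assert supports_concurrent_forecasting("timer_xl")
assert not supports_concurrent_forecasting("timer")
```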
@@ -34,6 +34,7 @@ def __init__(self):
Args:

"""
+super().__init__()

def batch_request(self, reqs: List[InferenceRequest]) -> torch.Tensor:
"""
@@ -46,17 +47,29 @@ def batch_request(self, reqs: List[InferenceRequest]) -> torch.Tensor:

Returns:
torch.Tensor: Concatenated input tensor of shape
-[sum(req.batch_size), length].
+[sum(req.batch_size), target_count, input_length].
"""
if not reqs:
raise ValueError("No requests provided to batch_request.")

-# Ensure length consistency
-length_set = {req.inputs.shape[1] for req in reqs}
-if len(length_set) != 1:
-raise ValueError(
-f"All requests must have the same length, " f"but got {length_set}"
-)
+# Ensure shape consistency
+first_target_count = reqs[0].target_count
+first_input_length = reqs[0].input_length
+
+for i, req in enumerate(reqs):
+if req.target_count != first_target_count:
+raise ValueError(
+f"All requests must have the same target_count, "
+f"but request 0 has {first_target_count} "
+f"and request {i} has {req.target_count}"
+)
+
+if req.input_length != first_input_length:
+raise ValueError(
+f"All requests must have the same input_length, "
+f"but request 0 has {first_input_length} "
+f"and request {i} has {req.input_length}"
+)

batch_inputs = torch.cat([req.inputs for req in reqs], dim=0)

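The batching contract after this change: every request carries a 3-D input of shape [batch_size, target_count, input_length], and requests can only be concatenated along dim 0 when the last two dimensions agree. A minimal sketch under those assumptions, with `_Req` as a hypothetical stand-in for `InferenceRequest` and arbitrary shapes:

```python
import torch

class _Req:
    """Hypothetical stand-in exposing only the fields batch_request relies on."""

    def __init__(self, inputs: torch.Tensor):
        self.inputs = inputs  # shape: [batch_size, target_count, input_length]
        self.batch_size = inputs.size(0)
        self.target_count = inputs.size(1)
        self.input_length = inputs.size(2)

# Two requests with matching target_count and input_length but different batch sizes.
r1 = _Req(torch.randn(2, 3, 96))
r2 = _Req(torch.randn(4, 3, 96))

batch_inputs = torch.cat([r.inputs for r in (r1, r2)], dim=0)
assert batch_inputs.shape == (6, 3, 96)  # [sum(batch_size), target_count, input_length]
```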
@@ -49,21 +49,20 @@ def __init__(
self.model_id = model_id
self.inputs = inputs
self.infer_kwargs = infer_kwargs
-self.output_length = (
-output_length # Number of time series data points to generate
-)
+self.output_length = output_length

self.batch_size = inputs.size(0)
-self.variable_size = inputs.size(1)
+self.target_count = inputs.size(1)
self.input_length = inputs.size(2)
self.state = InferenceRequestState.WAITING
self.cur_step_idx = 0 # Current write position in the output step index
self.assigned_pool_id = -1 # The pool handling this request
self.assigned_device_id = -1 # The device handling this request

-# Preallocate output buffer [batch_size, max_new_tokens]
+# Preallocate output buffer [batch_size, target_count, output_length]
self.output_tensor = torch.zeros(
-self.batch_size, self.variable_size, output_length, device="cpu"
-) # shape: [batch_size, target_count, predict_length]
+self.batch_size, self.target_count, output_length, device="cpu"
+)

def mark_running(self):
self.state = InferenceRequestState.RUNNING
@@ -81,7 +80,7 @@ def write_step_output(self, step_output: torch.Tensor):
while step_output.ndim < 3:
step_output = step_output.unsqueeze(0)

-batch_size, variable_size, step_size = step_output.shape
+batch_size, target_count, step_size = step_output.shape
end_idx = self.cur_step_idx + step_size

if end_idx > self.output_length:
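How the preallocated [batch_size, target_count, output_length] buffer is presumably filled: each step writes a slice along the last dimension starting at cur_step_idx, and overruns are rejected. A sketch with made-up sizes mirroring the logic in write_step_output:

```python
import torch

batch_size, target_count, output_length = 2, 3, 8
output_tensor = torch.zeros(batch_size, target_count, output_length)
cur_step_idx = 0

# One decoding step produces 4 new points per target series.
step_output = torch.randn(batch_size, target_count, 4)

end_idx = cur_step_idx + step_output.shape[-1]
if end_idx > output_length:
    raise ValueError("step output would overflow the preallocated buffer")
output_tensor[:, :, cur_step_idx:end_idx] = step_output
cur_step_idx = end_idx  # the next step continues writing from here
```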
@@ -115,7 +115,7 @@ def _step(self):

grouped_requests = defaultdict(list)
for req in all_requests:
-key = (req.inputs.shape[1], req.output_length)
+key = (req.target_count, req.input_length, req.output_length)
grouped_requests[key].append(req)
grouped_requests = list(grouped_requests.values())

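Requests are now bucketed by the full shape signature (target_count, input_length, output_length) rather than by (inputs.shape[1], output_length), so only requests whose tensors can actually be concatenated land in the same batch. A small sketch with tuples standing in for requests:

```python
from collections import defaultdict

# Each tuple stands in for a request's (target_count, input_length, output_length).
all_requests = [(3, 96, 24), (3, 96, 24), (1, 96, 24), (3, 192, 24)]

grouped_requests = defaultdict(list)
for req in all_requests:
    target_count, input_length, output_length = req
    grouped_requests[(target_count, input_length, output_length)].append(req)

# Three distinct shape signatures, so three batchable groups.
assert len(grouped_requests) == 3
assert len(grouped_requests[(3, 96, 24)]) == 2
```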
@@ -20,6 +20,7 @@

import torch

+from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.inference.pool_group import PoolGroup
from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
AbstractPoolScheduler,
@@ -51,6 +52,14 @@ def _estimate_shared_pool_size_by_total_mem(
Returns:
mapping {model_id: pool_num}
"""

+# Check if the model supports concurrent forecasting
+if new_model_info and new_model_info.model_id not in MODEL_MEM_USAGE_MAP:
+logger.error(
+f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {new_model_info.model_id} does not support concurrent forecasting."
+)
+raise ModelNotExistException(new_model_info.model_id)
+
# Extract unique model IDs
all_models = existing_model_infos + (
[new_model_info] if new_model_info is not None else []
@@ -60,7 +69,7 @@
mem_usages: Dict[str, float] = {}
for model_info in all_models:
mem_usages[model_info.model_id] = (
-MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
+MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
)

# Evaluate system resources and get TOTAL memory
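The new guard fails fast when a model has no entry in MODEL_MEM_USAGE_MAP, and the per-model estimate is now keyed by model_id rather than model_type. A sketch of both steps with made-up constants; the exception type and ratio value below are illustrative, not taken from the patch:

```python
MODEL_MEM_USAGE_MAP = {"sundial": 1036 * 1024**2, "timer_xl": 856 * 1024**2}
INFERENCE_EXTRA_MEMORY_RATIO = 1.2  # illustrative safety factor

def estimated_model_mem(model_id: str) -> float:
    if model_id not in MODEL_MEM_USAGE_MAP:
        # Mirrors the early error: unknown models cannot be pooled for
        # concurrent forecasting.
        raise KeyError(f"{model_id} does not support concurrent forecasting")
    return MODEL_MEM_USAGE_MAP[model_id] * INFERENCE_EXTRA_MEMORY_RATIO

assert estimated_model_mem("timer_xl") > 856 * 1024**2
```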
@@ -253,7 +253,7 @@ def _run(

return resp_cls(
get_status(TSStatusCode.SUCCESS_STATUS),
-output_list[0] if single_batch else output_list,
+[output_list[0]] if single_batch else output_list,
)

except Exception as e:
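With forecastResult now a list, the handler wraps a single-batch result instead of returning the bare binary. A hypothetical sketch of that wrapping; `build_forecast_result` is not a real function in the codebase:

```python
def build_forecast_result(output_list, single_batch: bool):
    # output_list holds one serialized TsBlock (bytes) per batch.
    return [output_list[0]] if single_batch else output_list

assert build_forecast_result([b"block-0"], single_batch=True) == [b"block-0"]
assert build_forecast_result([b"a", b"b"], single_batch=False) == [b"a", b"b"]
```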
4 changes: 1 addition & 3 deletions iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -67,9 +67,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
system_res = evaluate_system_resources(device)
free_mem = system_res["free_mem"]

-mem_usage = (
-MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
-)
+mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
if size <= 0:
logger.error(
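A worked example of the pool-size formula with illustrative numbers (8 GiB free, INFERENCE_MEMORY_USAGE_RATIO of 0.2 as in constant.py, the timer_xl footprint from the map, and an assumed extra-memory ratio of 1.0):

```python
free_mem = 8 * 1024**3               # illustrative free device memory: 8 GiB
INFERENCE_MEMORY_USAGE_RATIO = 0.2   # from constant.py
INFERENCE_EXTRA_MEMORY_RATIO = 1.0   # assumed for this example

mem_usage = 856 * 1024**2 * INFERENCE_EXTRA_MEMORY_RATIO  # timer_xl footprint
size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)

# 0.2 * 8192 MiB gives roughly a 1638 MiB budget; 1638 // 856 = 1 pool fits.
assert size == 1
```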
@@ -377,7 +377,7 @@ private TsBlock classify() {
"Error occurred while executing classify:[%s]", resp.getStatus().getMessage());
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}
-return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+return SERDE.deserialize(resp.forecastResult.get(0));
}
}
}
@@ -574,7 +574,7 @@ protected TsBlock forecast() {
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}

-TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+TsBlock res = SERDE.deserialize(resp.forecastResult.get(0));
if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
throw new IoTDBRuntimeException(
String.format(
@@ -40,7 +40,6 @@
import org.apache.tsfile.read.common.block.column.TsBlockSerde;

import java.io.IOException;
-import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -226,7 +225,7 @@ private TsBlock forecast() throws Exception {
resp.getStatus().getCode(), resp.getStatus().getMessage()),
resp.getStatus().getCode());
}
-return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+return serde.deserialize(resp.forecastResult.get(0));
}

@Override
2 changes: 1 addition & 1 deletion iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -92,7 +92,7 @@ struct TForecastReq {

struct TForecastResp {
1: required common.TSStatus status
-2: optional binary forecastResult
+2: optional list<binary> forecastResult
}

struct TShowModelsReq {
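Clients now receive a list of serialized blocks in TForecastResp; reading the first element replaces the old single-binary access (the Java operators above do the same via resp.forecastResult.get(0)). A hypothetical Python-side sketch of a consumer:

```python
def first_forecast_block(resp) -> bytes:
    # resp.forecastResult is Thrift's list<binary>, mapped to a Python list of bytes.
    if not resp.forecastResult:
        raise ValueError("TForecastResp carried an empty forecastResult")
    return resp.forecastResult[0]
```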