Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions lightllm/server/api_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing_extensions import deprecated
import uuid

from pydantic import BaseModel, Field, field_validator, model_validator
Expand Down Expand Up @@ -114,7 +115,10 @@ class CompletionRequest(BaseModel):
# prompt: string or tokens
prompt: Union[str, List[str], List[int], List[List[int]]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 8192
max_tokens: Optional[int] = Field(
default=None, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
)
max_completion_tokens: Optional[int] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
Expand Down Expand Up @@ -187,7 +191,10 @@ class ChatCompletionRequest(BaseModel):
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = 8192
max_tokens: Optional[int] = Field(
default=None, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
)
max_completion_tokens: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
Expand Down
17 changes: 13 additions & 4 deletions lightllm/server/api_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
"top_p": request.top_p,
"top_k": request.top_k,
"ignore_eos": request.ignore_eos,
"max_new_tokens": request.max_tokens,
"stop_sequences": request.stop,
"n": request.n,
"best_of": request.n,
"add_special_tokens": False,
"seed": request.seed,
}

# Resolve max_new_tokens from the two request fields. The new OpenAI-style
# max_completion_tokens takes precedence; the deprecated max_tokens is only
# honored as a fallback, so a client sending both gets the non-deprecated
# value instead of having it silently overwritten.
if request.max_completion_tokens is not None:
    sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
elif request.max_tokens is not None:
    sampling_params_dict["max_new_tokens"] = request.max_tokens
Comment on lines +212 to +215
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for setting max_new_tokens gives precedence to the deprecated max_tokens parameter. If both max_completion_tokens and max_tokens are provided, the value from max_tokens will overwrite the one from max_completion_tokens. To ensure the new parameter is prioritized, this logic should be adjusted.

Suggested change
if request.max_completion_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
if request.max_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_tokens
if request.max_completion_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
elif request.max_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_tokens

if request.stop is not None:
sampling_params_dict["stop_sequences"] = request.stop

# Structured output handling
if request.response_format:
if request.response_format.type == "json_schema":
Expand Down Expand Up @@ -533,13 +538,17 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) ->
"top_p": request.top_p,
"top_k": request.top_k,
"ignore_eos": request.ignore_eos,
"max_new_tokens": request.max_tokens,
"stop_sequences": request.stop,
"n": request.n,
"best_of": request.best_of,
"add_special_tokens": False,
"seed": request.seed,
}
# Same precedence rule as the chat-completions endpoint: prefer the new
# max_completion_tokens and fall back to the deprecated max_tokens only
# when the new parameter was not supplied, so max_tokens can no longer
# clobber max_completion_tokens when a client sends both.
if request.max_completion_tokens is not None:
    sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
elif request.max_tokens is not None:
    sampling_params_dict["max_new_tokens"] = request.max_tokens
Comment on lines +546 to +549
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the chat completions endpoint, the logic here incorrectly gives precedence to the deprecated max_tokens parameter. The new max_completion_tokens parameter should be prioritized to ensure correct behavior and a smooth transition for users.

Suggested change
if request.max_completion_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
if request.max_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_tokens
if request.max_completion_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
elif request.max_tokens is not None:
sampling_params_dict["max_new_tokens"] = request.max_tokens

if request.stop is not None:
sampling_params_dict["stop_sequences"] = request.stop

if request.response_format:
if request.response_format.type == "json_schema":
Expand Down
21 changes: 12 additions & 9 deletions lightllm/server/core/objs/py_sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
top_k: int = None, # -1 is for all
ignore_eos: bool = False,
image_max_patch_num: int = -1,
max_new_tokens: int = 16,
max_new_tokens: int = -1,
min_new_tokens: int = 1,
stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None, # 停止句子条件
skip_special_tokens: bool = True, # whether to skip special tokens when decoding
Expand Down Expand Up @@ -141,14 +141,6 @@ def verify(self):
raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
if self.min_new_tokens < 1:
raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
if self.min_new_tokens > self.max_new_tokens:
raise ValueError(
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
)

if len(self.exponential_decay_length_penalty) != 2:
raise ValueError(
Expand Down Expand Up @@ -201,6 +193,17 @@ def verify(self):

return

def verify_length(self):
    """Validate that the new-token length limits form a sane range.

    Intended to be called after ``max_new_tokens`` has been resolved to a
    concrete positive value (it may initially be a sentinel meaning
    "fill up to the request limit" — resolved elsewhere before this check).

    Raises:
        ValueError: if ``max_new_tokens`` or ``min_new_tokens`` is below 1,
            or if ``min_new_tokens`` exceeds ``max_new_tokens``.
    """
    max_toks = self.max_new_tokens
    min_toks = self.min_new_tokens
    # Check max first, then min, then the ordering constraint — the first
    # violated condition determines which error message the caller sees.
    if max_toks < 1:
        raise ValueError(f"max_new_tokens must be at least 1, got {max_toks}.")
    if min_toks < 1:
        raise ValueError(f"min_new_tokens must be at least 1, got {min_toks}.")
    if min_toks > max_toks:
        raise ValueError(f"min_new_tokens must <= max_new_tokens, but got min {min_toks}, max {max_toks}.")
    return

def _verify_allowed_token_ids(self):
if self.allowed_token_ids is not None:
if (not isinstance(self.allowed_token_ids, list)) or (
Expand Down
12 changes: 7 additions & 5 deletions lightllm/server/core/objs/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
self.ignore_eos = kwargs.get("ignore_eos", False)
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
self.max_new_tokens = kwargs.get("max_new_tokens", 16)
self.max_new_tokens = kwargs.get("max_new_tokens", -1)
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
self.group_request_id = kwargs.get("group_request_id", -1)
Expand Down Expand Up @@ -439,6 +439,12 @@ def verify(self):
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
self._verify_allowed_token_ids()
self._verify_grammar_constraint()

return

def verify_length(self):
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
if self.min_new_tokens < 1:
Expand All @@ -447,10 +453,6 @@ def verify(self):
raise ValueError(
f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
)

self._verify_allowed_token_ids()
self._verify_grammar_constraint()

return

def _verify_grammar_constraint(self):
Expand Down
3 changes: 3 additions & 0 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
if not prompt_ids:
raise ValueError("prompt_ids is empty")
prompt_tokens = len(prompt_ids)
if sampling_params.max_new_tokens == -1:
sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
# use long_truncation_mode to truncate long input len req.
if self.args.long_truncation_mode is None:
Expand All @@ -472,6 +474,7 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
assert prompt_tokens == req_input_len
else:
assert False, "error args"
sampling_params.verify_length()

# last repaired
req_total_len = len(prompt_ids) + sampling_params.max_new_tokens
Expand Down