diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index e194a24d0..40672280f 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -10,10 +10,15 @@ class ImageURL(BaseModel): url: str +class AudioURL(BaseModel): + url: str + + class MessageContent(BaseModel): type: str text: Optional[str] = None image_url: Optional[ImageURL] = None + audio_url: Optional[AudioURL] = None class Message(BaseModel): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d6a9c789a..eacd8d67f 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -174,7 +174,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req created_time = int(time.time()) - multimodal_params_dict = {"images": []} + multimodal_params_dict = {"images": [], "audios": []} for message in request.messages: if isinstance(message.content, list): texts = [] @@ -197,6 +197,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req raise ValueError( "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." ) + elif content.type == "audio_url" and content.audio_url is not None: + audio = content.audio_url.url + if audio.startswith("http://") or audio.startswith("https://"): + multimodal_params_dict["audios"].append({"type": "url", "data": audio}) + elif audio.startswith("data:audio"): + data_str = audio.split(";", 1)[1] + if data_str.startswith("base64,"): + data = data_str[7:] + multimodal_params_dict["audios"].append({"type": "base64", "data": data}) + else: + raise ValueError("Unrecognized audio input.") + else: + raise ValueError("Unrecognized audio input. Supports local path, http url, base64.") tools = None if request.tools and request.tool_choice != "none": diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index f770459a5..7f16d519a 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,9 +1,15 @@ +import os +import json +from lightllm.server.tokenizer import get_tokenizer +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + tokenizer = None def init_tokenizer(args): global tokenizer - from lightllm.server.tokenizer import get_tokenizer tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) chat_path = args.chat_template @@ -11,6 +17,29 @@ def init_tokenizer(args): with open(chat_path, "r", encoding="utf-8") as f: chat_template_str = f.read() tokenizer.chat_template = chat_template_str + return + + # 如果 tokenizer 目录下存在chat_template.json, 同时不存在 chat_template.jinja, + # 则加载其并赋值给tokenizer 的 chat_template 对象。 + if not os.path.exists(os.path.join(args.model_dir, "chat_template.jinja")) and os.path.exists( + os.path.join(args.model_dir, "chat_template.json") + ): + default_chat_template_path = os.path.join(args.model_dir, "chat_template.json") + try: + with open(default_chat_template_path, "r", encoding="utf-8") as f: + template_data = json.load(f) + if "chat_template" in template_data: + # Set it directly on the tokenizer object so apply_chat_template can use it + if hasattr(tokenizer, "tokenizer"): + # 多模态 tokenizer + tokenizer.tokenizer.chat_template = template_data["chat_template"] + else: + tokenizer.chat_template = template_data["chat_template"] + + logger.info(f"Loaded chat_template.json from {default_chat_template_path}") + except Exception as e: + logger.warning(f"Failed to load chat_template.json from {default_chat_template_path}: {e}") + return async def build_prompt(request, tools) -> str: diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index d957c7649..09a07455b 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -67,6 +67,15 @@ def to_dict(self): ret["start_index_in_embed_cache"] = self.start_index_in_embed_cache return ret + def to_origin_dict(self): + """ + 将内容转换为原始请求的形式,主要用于请求转发 + """ + ret = {} + ret["type"] = self._type + ret["data"] = self._data + return ret + class ImageItem: def __init__(self, **kwargs): @@ -173,4 +182,5 @@ def to_origin_dict(self): """ ret = {} ret["images"] = [i.to_origin_dict() for i in self.images] + ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret