diff --git a/src/art/tinker/cookbook_v/image_processing_utils.py b/src/art/tinker/cookbook_v/image_processing_utils.py
index 1d405176..3d3816b2 100644
--- a/src/art/tinker/cookbook_v/image_processing_utils.py
+++ b/src/art/tinker/cookbook_v/image_processing_utils.py
@@ -28,8 +28,12 @@ def get_image_processor(model_name: str) -> ImageProcessor:
from transformers.models.auto.image_processing_auto import AutoImageProcessor
- processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
- assert processor.is_fast, f"Could not load fast image processor for {model_name}"
+ kwargs: dict[str, Any] = {}
+ if model_name == "moonshotai/Kimi-K2.5":
+ kwargs["trust_remote_code"] = True
+ kwargs["revision"] = "3367c8d1c68584429fab7faf845a32d5195b6ac1"
+
+ processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True, **kwargs)
return processor
diff --git a/src/art/tinker/cookbook_v/renderers/__init__.py b/src/art/tinker/cookbook_v/renderers/__init__.py
index 0e03e8a3..ec776af9 100644
--- a/src/art/tinker/cookbook_v/renderers/__init__.py
+++ b/src/art/tinker/cookbook_v/renderers/__init__.py
@@ -5,6 +5,9 @@
python -m tinker_cookbook.supervised.viz_sft_dataset dataset_path=Tulu3Builder renderer_name=role_colon
"""
+from collections.abc import Callable
+from typing import Any
+
from ..image_processing_utils import ImageProcessor
from ..tokenizer_utils import Tokenizer
@@ -14,15 +17,21 @@
ContentPart,
ImagePart,
Message,
+ # Streaming types
+ MessageDelta,
# Renderer base
RenderContext,
Renderer,
Role,
+ StreamingMessageHeader,
+ StreamingTextDelta,
+ StreamingThinkingDelta,
TextPart,
ThinkingPart,
ToolCall,
ToolSpec,
TrainOnWhat,
+ Utf8TokenDecoder,
# Utility functions
ensure_text,
format_content_as_string,
@@ -35,9 +44,59 @@
from .gpt_oss import GptOssRenderer
from .qwen3 import Qwen3Renderer
+# Global registry for custom renderer factories
+_CUSTOM_RENDERER_REGISTRY: dict[str, Callable[[Tokenizer, Any], Renderer]] = {}
+
+
+def register_renderer(
+ name: str,
+ factory: Callable[[Tokenizer, Any], Renderer],
+) -> None:
+ """Register a custom renderer factory.
+
+ Args:
+ name: The renderer name
+ factory: A callable that takes (tokenizer, image_processor) and returns a Renderer.
+
+ Example:
+ def my_renderer_factory(tokenizer, image_processor=None):
+ return MyCustomRenderer(tokenizer)
+
+ register_renderer("Foo/foo_renderer", my_renderer_factory)
+ """
+ _CUSTOM_RENDERER_REGISTRY[name] = factory
+
+
+def get_registered_renderer_names() -> list[str]:
+ """Return a list of all registered custom renderer names."""
+ return list(_CUSTOM_RENDERER_REGISTRY.keys())
+
+
+def is_renderer_registered(name: str) -> bool:
+ """Check if a renderer name is registered."""
+ return name in _CUSTOM_RENDERER_REGISTRY
+
+
+def unregister_renderer(name: str) -> bool:
+ """Unregister a custom renderer factory.
+
+ Args:
+ name: The renderer name to unregister.
+
+ Returns:
+ True if the renderer was unregistered, False if it wasn't registered.
+ """
+ if name in _CUSTOM_RENDERER_REGISTRY:
+ del _CUSTOM_RENDERER_REGISTRY[name]
+ return True
+ return False
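+
+
+# A minimal end-to-end sketch of the registry API (names are illustrative;
+# MyRenderer stands in for any Renderer subclass you define):
+#
+# register_renderer("my/renderer", lambda tokenizer, image_processor=None: MyRenderer(tokenizer))
+# assert is_renderer_registered("my/renderer")
+# renderer = get_renderer("my/renderer", tokenizer)
+# unregister_renderer("my/renderer")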
+
def get_renderer(
- name: str, tokenizer: Tokenizer, image_processor: ImageProcessor | None = None
+ name: str,
+ tokenizer: Tokenizer,
+ image_processor: ImageProcessor | None = None,
+ model_name: str | None = None,
) -> Renderer:
"""Factory function to create renderers by name.
@@ -50,16 +109,24 @@ def get_renderer(
- "qwen3_vl_instruct": Qwen3 vision-language instruct (no thinking)
- "qwen3_disable_thinking": Qwen3 with thinking disabled
- "qwen3_instruct": Qwen3 instruct 2507 (no thinking)
+ - "qwen3_5": Qwen3.5 VL with thinking
+ - "qwen3_5_disable_thinking": Qwen3.5 VL with thinking disabled
- "deepseekv3": DeepSeek V3 (defaults to non-thinking mode)
- "deepseekv3_disable_thinking": DeepSeek V3 non-thinking (alias)
- "deepseekv3_thinking": DeepSeek V3 thinking mode
- "kimi_k2": Kimi K2 Thinking format
+ - "kimi_k25": Kimi K2.5 with thinking enabled
+ - "kimi_k25_disable_thinking": Kimi K2.5 with thinking disabled
- "gpt_oss_no_sysprompt": GPT-OSS without system prompt
- "gpt_oss_low_reasoning": GPT-OSS with low reasoning
- "gpt_oss_medium_reasoning": GPT-OSS with medium reasoning
- "gpt_oss_high_reasoning": GPT-OSS with high reasoning
+ - Custom renderers registered via register_renderer()
tokenizer: The tokenizer to use.
image_processor: Required for VL renderers.
+ model_name: Model name for pickle metadata. If None, falls back to
+ ``tokenizer.name_or_path``. Provide this explicitly when the tokenizer
+ was loaded with a remapped name (e.g., Llama 3 models).
Returns:
A Renderer instance.
@@ -68,10 +135,25 @@ def get_renderer(
ValueError: If the renderer name is unknown.
AssertionError: If a VL renderer is requested without an image_processor.
"""
+
+ def _stamp_pickle_metadata(renderer: Renderer) -> Renderer:
+ """Stamp renderer with metadata needed for pickle support."""
+ renderer._renderer_name = name
+ renderer._model_name = (
+ model_name if model_name is not None else tokenizer.name_or_path
+ )
+ renderer._has_image_processor = image_processor is not None
+ return renderer
+
+ # Check custom registry first
+ if (factory := _CUSTOM_RENDERER_REGISTRY.get(name)) is not None:
+ return _stamp_pickle_metadata(factory(tokenizer, image_processor))
+
# Import renderer classes lazily to avoid circular imports and keep exports minimal
from .deepseek_v3 import DeepSeekV3DisableThinkingRenderer
from .gpt_oss import GptOssRenderer
from .kimi_k2 import KimiK2Renderer
+ from .kimi_k25 import KimiK25DisableThinkingRenderer, KimiK25Renderer
from .llama3 import Llama3Renderer
from .qwen3 import (
Qwen3DisableThinkingRenderer,
@@ -79,52 +161,72 @@ def get_renderer(
Qwen3VLInstructRenderer,
Qwen3VLRenderer,
)
+ from .qwen3_5 import Qwen3_5DisableThinkingRenderer, Qwen3_5Renderer
from .role_colon import RoleColonRenderer
+ renderer: Renderer
if name == "role_colon":
- return RoleColonRenderer(tokenizer)
+ renderer = RoleColonRenderer(tokenizer)
elif name == "llama3":
- return Llama3Renderer(tokenizer)
+ renderer = Llama3Renderer(tokenizer)
elif name == "qwen3":
- return Qwen3Renderer(tokenizer)
+ renderer = Qwen3Renderer(tokenizer)
elif name == "qwen3_vl":
assert image_processor is not None, (
"qwen3_vl renderer requires an image_processor"
)
- return Qwen3VLRenderer(tokenizer, image_processor)
+ renderer = Qwen3VLRenderer(tokenizer, image_processor)
elif name == "qwen3_vl_instruct":
assert image_processor is not None, (
"qwen3_vl_instruct renderer requires an image_processor"
)
- return Qwen3VLInstructRenderer(tokenizer, image_processor)
+ renderer = Qwen3VLInstructRenderer(tokenizer, image_processor)
elif name == "qwen3_disable_thinking":
- return Qwen3DisableThinkingRenderer(tokenizer)
+ renderer = Qwen3DisableThinkingRenderer(tokenizer)
elif name == "qwen3_instruct":
- return Qwen3InstructRenderer(tokenizer)
+ renderer = Qwen3InstructRenderer(tokenizer)
+ elif name == "qwen3_5":
+ renderer = Qwen3_5Renderer(tokenizer, image_processor=image_processor)
+ elif name == "qwen3_5_disable_thinking":
+ renderer = Qwen3_5DisableThinkingRenderer(
+ tokenizer, image_processor=image_processor
+ )
elif name == "deepseekv3":
# Default to non-thinking mode (matches HF template default behavior)
- return DeepSeekV3DisableThinkingRenderer(tokenizer)
+ renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)
elif name == "deepseekv3_disable_thinking":
# Alias for backward compatibility
- return DeepSeekV3DisableThinkingRenderer(tokenizer)
+ renderer = DeepSeekV3DisableThinkingRenderer(tokenizer)
elif name == "deepseekv3_thinking":
- return DeepSeekV3ThinkingRenderer(tokenizer)
+ renderer = DeepSeekV3ThinkingRenderer(tokenizer)
elif name == "kimi_k2":
- return KimiK2Renderer(tokenizer)
+ renderer = KimiK2Renderer(tokenizer)
+ elif name == "kimi_k25":
+ renderer = KimiK25Renderer(tokenizer, image_processor=image_processor)
+ elif name == "kimi_k25_disable_thinking":
+ renderer = KimiK25DisableThinkingRenderer(
+ tokenizer, image_processor=image_processor
+ )
elif name == "gpt_oss_no_sysprompt":
- return GptOssRenderer(tokenizer, use_system_prompt=False)
+ renderer = GptOssRenderer(tokenizer, use_system_prompt=False)
elif name == "gpt_oss_low_reasoning":
- return GptOssRenderer(tokenizer, use_system_prompt=True, reasoning_effort="low")
+ renderer = GptOssRenderer(
+ tokenizer, use_system_prompt=True, reasoning_effort="low"
+ )
elif name == "gpt_oss_medium_reasoning":
- return GptOssRenderer(
+ renderer = GptOssRenderer(
tokenizer, use_system_prompt=True, reasoning_effort="medium"
)
elif name == "gpt_oss_high_reasoning":
- return GptOssRenderer(
+ renderer = GptOssRenderer(
tokenizer, use_system_prompt=True, reasoning_effort="high"
)
else:
- raise ValueError(f"Unknown renderer: {name}")
+ raise ValueError(
+ f"Unknown renderer: {name}. If this is a custom renderer, please register it via register_renderer()."
+ )
+
+ return _stamp_pickle_metadata(renderer)
__all__ = [
@@ -137,6 +239,12 @@ def get_renderer(
"ThinkingPart",
"ToolCall",
"ToolSpec",
+ # Streaming types
+ "MessageDelta",
+ "StreamingMessageHeader",
+ "StreamingTextDelta",
+ "StreamingThinkingDelta",
+ "Utf8TokenDecoder",
# Renderer base
"RenderContext",
"Renderer",
@@ -146,6 +254,11 @@ def get_renderer(
"format_content_as_string",
"get_text_content",
"parse_content_blocks",
+ # Registry
+ "register_renderer",
+ "unregister_renderer",
+ "get_registered_renderer_names",
+ "is_renderer_registered",
# Factory
"get_renderer",
# Renderer classes (used by tests)
diff --git a/src/art/tinker/cookbook_v/renderers/base.py b/src/art/tinker/cookbook_v/renderers/base.py
index b46874e9..9f8539e5 100644
--- a/src/art/tinker/cookbook_v/renderers/base.py
+++ b/src/art/tinker/cookbook_v/renderers/base.py
@@ -11,8 +11,9 @@
import io
import json
import logging
+import pickle
import re
-from typing import Literal, NotRequired, Optional, Protocol, TypedDict
+from typing import Any, ClassVar, Literal, NotRequired, Optional, Protocol, TypedDict, Union
import urllib.request
from PIL import Image
@@ -124,23 +125,169 @@ class ThinkingPart(TypedDict):
thinking: str # The thinking/reasoning content
-class ToolCallPart(TypedDict):
- """Tool/function call as a content part, preserving position in content list."""
+# Container for a part of a multimodal message content.
+# Tool calls live exclusively in message["tool_calls"] / message["unparsed_tool_calls"].
+ContentPart = TextPart | ImagePart | ThinkingPart
- type: Literal["tool_call"]
- tool_call: ToolCall # The parsed tool call object
+# Streaming types to enable incremental parsing of model output for real-time display.
-class UnparsedToolCallPart(TypedDict):
- """Tool call that failed to parse, preserving raw text for debugging."""
- type: Literal["unparsed_tool_call"]
- raw_text: str # Raw text of the tool call block including tags
- error: str # Description of what went wrong during parsing
+@dataclass
+class StreamingMessageHeader:
+ """Emitted at the start of a new message during streaming.
+
+ This signals that a new message is beginning and provides the author info.
+ """
+
+ role: str
+ name: str | None = None
+
+
+@dataclass
+class StreamingTextDelta:
+ """Incremental text content during streaming.
+
+ Contains only the new text since the last delta, not the accumulated text.
+ The recipient should concatenate deltas to build the full content.
+ """
+
+ text: str
+ content_index: int = 0
+ """Index of this content block within the message. Increments when content type changes."""
+
+
+@dataclass
+class StreamingThinkingDelta:
+ """Incremental thinking/reasoning content during streaming.
+
+ Contains only the new thinking text since the last delta.
+ """
+
+ thinking: str
+ content_index: int = 0
+ """Index of this content block within the message. Increments when content type changes."""
+
+
+# Union of all streaming update types.
+# A streaming parser yields these in sequence:
+# 1. StreamingMessageHeader (once at start)
+# 2. StreamingTextDelta / StreamingThinkingDelta (as content arrives)
+# 3. Message (once at end, containing the complete parsed message)
+MessageDelta = Union[
+ StreamingMessageHeader, StreamingTextDelta, StreamingThinkingDelta, "Message"
+]
+
+
+# Unicode replacement character - indicates incomplete/invalid UTF-8 sequence
+_REPLACEMENT_CHAR = "\ufffd"
+
+
+@dataclass
+class Utf8TokenDecoder:
+ """Handles incremental UTF-8 decoding from tokens.
+ Tokens can split multi-byte UTF-8 sequences (e.g., a 3-byte character
+ might be split across 2 tokens). This class buffers tokens until a
+ valid UTF-8 string can be decoded.
-# Container for a part of a multimodal message content
-ContentPart = TextPart | ImagePart | ThinkingPart | ToolCallPart | UnparsedToolCallPart
+ Detection strategy:
+ 1. Try decoding all pending + new tokens
+ 2. If result contains trailing U+FFFD (replacement char), it's incomplete
+ 3. Scan backwards to find longest prefix without trailing replacement chars
+ 4. Emit that prefix, buffer the rest
+
+ This handles tiktoken-style tokenizers that return replacement chars
+ instead of throwing exceptions for incomplete UTF-8.
+ """
+
+ tokenizer: "Tokenizer"
+ _pending_tokens: list[int] = None # type: ignore[assignment]
+
+ def __post_init__(self) -> None:
+ if self._pending_tokens is None:
+ self._pending_tokens = []
+
+ # Max tokens to try removing from the end when looking for decodable prefix.
+ # UTF-8 chars are max 4 bytes, tokens typically 1-4 bytes each,
+ # so 8 tokens is plenty to cover any incomplete trailing sequence.
+ # ClassVar keeps this constant out of the generated dataclass __init__.
+ _MAX_TRAILING_TOKENS_TO_TRY: ClassVar[int] = 8
+
+ def _is_valid_decode(self, text: str) -> bool:
+ """Check if decoded text represents a complete UTF-8 sequence.
+
+ Returns False if the text ends with a replacement character,
+ which indicates an incomplete multi-byte sequence that needs
+ more tokens to complete.
+ """
+ return not text.endswith(_REPLACEMENT_CHAR)
+
+ def decode(self, tokens: list[int]) -> str | None:
+ """Decode tokens to string, buffering incomplete UTF-8 sequences.
+
+ Args:
+ tokens: New tokens to decode.
+
+ Returns:
+ Decoded string if complete UTF-8 sequences are available,
+ None if all tokens were buffered (incomplete sequence).
+ """
+ self._pending_tokens.extend(tokens)
+
+ # Try to decode all pending tokens (common case)
+ try:
+ text = self.tokenizer.decode(self._pending_tokens)
+ if self._is_valid_decode(text):
+ self._pending_tokens = []
+ return text
+ # Has trailing replacement chars - fall through to find valid prefix
+ except Exception:
+ pass
+
+ # Scan backwards to find longest decodable prefix without replacement chars.
+ # We only need to try removing a few tokens since UTF-8 sequences are at
+ # most 4 bytes and tokens are typically 1-4 bytes each.
+ for remove in range(
+ 1, min(len(self._pending_tokens), self._MAX_TRAILING_TOKENS_TO_TRY) + 1
+ ):
+ prefix = self._pending_tokens[:-remove]
+ if not prefix:
+ break
+ try:
+ text = self.tokenizer.decode(prefix)
+ if self._is_valid_decode(text):
+ self._pending_tokens = self._pending_tokens[-remove:]
+ return text
+ except Exception:
+ continue
+
+ # All tokens buffered - need more data
+ return None
+
+ def flush(self) -> str:
+ """Force decode any remaining tokens.
+
+ Call this at end of stream. May produce replacement characters
+ for incomplete sequences.
+ """
+ if not self._pending_tokens:
+ return ""
+ try:
+ text = self.tokenizer.decode(self._pending_tokens)
+ except Exception:
+ # Most tokenizers substitute replacement characters rather than raise;
+ # if decode does fail, fall back to an empty string as a last resort.
+ text = ""
+ self._pending_tokens = []
+ return text
+
+ def reset(self) -> None:
+ """Clear any buffered tokens."""
+ self._pending_tokens = []
+
+ def has_pending(self) -> bool:
+ """Check if there are buffered tokens waiting for more data."""
+ return len(self._pending_tokens) > 0
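+
+
+# A minimal usage sketch for Utf8TokenDecoder (assumes `tokenizer` is any object
+# with an encode/decode pair; the token IDs below are illustrative, not real):
+#
+# decoder = Utf8TokenDecoder(tokenizer)
+# out = decoder.decode([101]) # None if the token ends mid-character
+# out = decoder.decode([102]) # buffered tokens are retried together with new ones
+# tail = decoder.flush() # force out whatever remains at end of stream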
# NOTE: we use a broad type definition for the role to be flexible
@@ -212,6 +359,14 @@ class RenderContext:
prev_message: Message | None = None
"""The previous message in the conversation, if any."""
+ last_user_index: int = -1
+ """Index of the last user message in the conversation. -1 if no user messages.
+
+ This is computed by the base build_generation_prompt/build_supervised_example
+ and used by renderers like Qwen3.5 that need to treat assistant messages
+ differently based on whether they come before or after the last user message.
+ """
+
class ToolSpec(TypedDict):
"""
@@ -287,7 +442,7 @@ def format_content_as_string(content: Content, separator: str = "\n") -> str:
"""Format message content as a string, preserving all part types.
Unlike get_text_content which only extracts text parts, this formats
- all content parts (thinking, text, tool_call, etc.) as a readable string.
+ all content parts (thinking, text) as a readable string.
This is useful for compatibility with APIs that expect string content
(e.g., OpenAI Chat Completions API), but we don't recommend it if you
@@ -310,13 +465,6 @@ def format_content_as_string(content: Content, separator: str = "\n") -> str:
parts.append(f"{p['thinking']}")
elif p["type"] == "text":
parts.append(p["text"])
- elif p["type"] == "tool_call":
- tc = p["tool_call"]
- parts.append(
- f"{tc.function.name}({tc.function.arguments})"
- )
- elif p["type"] == "unparsed_tool_call":
- parts.append(f"{p['raw_text']}")
else:
raise ValueError(f"Unknown content part type: {p['type']}")
return separator.join(parts)
@@ -362,37 +510,42 @@ def _parse_tool_call_json(
)
-def parse_content_blocks(content: str) -> list[ContentPart] | None:
+def parse_content_blocks(
+ content: str,
+) -> tuple[list[ContentPart], list[ToolCall | UnparsedToolCall]] | None:
"""
Parse a string with <think>...</think> and <tool_call>...</tool_call> tags.
- Handles interleaved thinking, tool call, and text blocks, returning parts
- in order. Empty parts are omitted. Failed tool call parses are included as
- UnparsedToolCallPart to preserve ordering.
+ Handles interleaved thinking, tool call, and text blocks. Content parts
+ (ThinkingPart, TextPart) are returned in the first element; tool calls
+ (ToolCall, UnparsedToolCall) are returned separately in the second element,
+ preserving their relative order.
- Whitespace is preserved exactly - roundtrip (parse then render) is identity.
+ Whitespace in non-tool-call regions is preserved exactly - roundtrip
+ (parse then render) is identity for the content parts.
Args:
content: String potentially containing <think> and/or <tool_call> blocks.
Returns:
- List of ContentPart (ThinkingPart, TextPart, ToolCallPart, UnparsedToolCallPart)
- in order. Returns None if no special tags are found - caller should use
- the original string for backward compatibility.
+ Tuple of (content_parts, tool_calls), or None if no special tags are found.
+ content_parts contains only ThinkingPart/TextPart.
+ tool_calls contains ToolCall and UnparsedToolCall in order.
Example:
>>> parse_content_blocks("<think>step 1</think>answer<tool_call>{...}</tool_call>more")
- [
- ThinkingPart(type="thinking", thinking="step 1"),
- TextPart(type="text", text="answer"),
- ToolCallPart(type="tool_call", tool_call=ToolCall(...)),
- TextPart(type="text", text="more"),
- ]
+ (
+ [ThinkingPart(type="thinking", thinking="step 1"),
+ TextPart(type="text", text="answer"),
+ TextPart(type="text", text="more")],
+ [ToolCall(...)],
+ )
"""
if "" not in content and "" not in content:
return None # No special blocks, caller should use original string
parts: list[ContentPart] = []
+ tool_calls: list[ToolCall | UnparsedToolCall] = []
pos = 0
# Pattern to find both <think>...</think> and <tool_call>...</tool_call> blocks
@@ -412,21 +565,10 @@ def parse_content_blocks(content: str) -> list[ContentPart] | None:
if thinking: # Skip empty thinking blocks
parts.append(ThinkingPart(type="thinking", thinking=thinking))
else:
- # This is a <tool_call> block
+ # This is a <tool_call> block; route it to the separate tool_calls list
tool_call_json = match.group(2)
raw_text = match.group(0) # Full match including tags
- parsed = _parse_tool_call_json(tool_call_json, raw_text)
- if isinstance(parsed, UnparsedToolCall):
- # Include unparsed tool calls as UnparsedToolCallPart to preserve order
- parts.append(
- UnparsedToolCallPart(
- type="unparsed_tool_call",
- raw_text=parsed.raw_text,
- error=parsed.error,
- )
- )
- else:
- parts.append(ToolCallPart(type="tool_call", tool_call=parsed))
+ tool_calls.append(_parse_tool_call_json(tool_call_json, raw_text))
pos = match.end()
@@ -435,7 +577,7 @@ def parse_content_blocks(content: str) -> list[ContentPart] | None:
if remaining: # Skip only truly empty strings
parts.append(TextPart(type="text", text=remaining))
- return parts
+ return parts, tool_calls
def parse_think_blocks(content: str) -> list[ContentPart] | None:
@@ -534,6 +676,7 @@ class RenderedMessage:
class TrainOnWhat(StrEnum):
LAST_ASSISTANT_MESSAGE = "last_assistant_message"
+ LAST_ASSISTANT_TURN = "last_assistant_turn"
ALL_ASSISTANT_MESSAGES = "all_assistant_messages"
ALL_MESSAGES = "all_messages"
ALL_TOKENS = "all_tokens"
@@ -541,6 +684,28 @@ class TrainOnWhat(StrEnum):
CUSTOMIZED = "customized"
+def _unpickle_renderer(
+ renderer_name: str, model_name: str, has_image_processor: bool
+) -> "Renderer":
+ """Reconstruct a Renderer from its name and model name.
+
+ Called by pickle to deserialize Renderer instances. Uses cached tokenizer/image_processor
+ so reconstruction cost is negligible after first call.
+ """
+ from ..tokenizer_utils import get_tokenizer
+ from . import get_renderer
+
+ tokenizer = get_tokenizer(model_name)
+ image_processor = None
+ if has_image_processor:
+ from ..image_processing_utils import get_image_processor
+
+ image_processor = get_image_processor(model_name)
+ return get_renderer(
+ renderer_name, tokenizer, image_processor, model_name=model_name
+ )
+
+
class Renderer(ABC):
"""
Abstract base class for rendering message lists into training and sampling prompts.
@@ -553,13 +718,47 @@ class Renderer(ABC):
The default build_generation_prompt and build_supervised_example implementations
assume simple concatenation of rendered messages. Override these if your renderer
modifies the conversation structure (e.g., stripping thinking blocks from history).
+
+ Pickle support: Renderers created via ``get_renderer()`` are automatically pickleable.
+ On deserialization, the tokenizer and image processor are reconstructed from cached
+ loaders, so the cost is negligible. Renderers created directly (not via ``get_renderer()``)
+ must set ``_renderer_name`` and ``_model_name`` manually to be pickleable.
+
+ Implementations of ``EnvGroupBuilder`` must be pickleable to support distributed rollout
+ execution. Since many builders store a Renderer, this pickle support is critical.
"""
tokenizer: Tokenizer
+ # Pickle metadata — set by get_renderer() via _stamp_pickle_metadata().
+ # Class-level defaults ensure these exist even when subclasses bypass super().__init__().
+ _renderer_name: str | None = None
+ _model_name: str | None = None
+ _has_image_processor: bool = False
+
def __init__(self, tokenizer: Tokenizer):
self.tokenizer = tokenizer
+ def __reduce__(self) -> tuple:
+ """Enable pickling by storing only (renderer_name, model_name, has_image_processor).
+
+ On unpickling, the Renderer is reconstructed via get_renderer() with a
+ cached tokenizer, so the cost is negligible.
+ """
+ renderer_name = getattr(self, "_renderer_name", None)
+ model_name = getattr(self, "_model_name", None)
+ has_image_processor = getattr(self, "_has_image_processor", False)
+ if renderer_name is None or model_name is None:
+ raise pickle.PicklingError(
+ f"Cannot pickle {type(self).__name__}: _renderer_name or _model_name not set. "
+ "Renderers must be created via get_renderer() to be pickleable, "
+ "or set _renderer_name and _model_name manually."
+ )
+ return (
+ _unpickle_renderer,
+ (renderer_name, model_name, has_image_processor),
+ )
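+
+ # A minimal pickle round-trip sketch (model name illustrative; the renderer must
+ # have been created via get_renderer() so the metadata is stamped):
+ #
+ # renderer = get_renderer("qwen3", get_tokenizer("Qwen/Qwen3-8B"))
+ # clone = pickle.loads(pickle.dumps(renderer)) # rebuilt via _unpickle_renderer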
+
@property
def has_extension_property(self) -> bool:
"""Whether this renderer satisfies the sequence extension property.
@@ -669,7 +868,6 @@ def to_openai_message(self, message: Message) -> dict:
"Images would be silently dropped, leading to incorrect HF template "
"comparisons or OpenAI API calls. Use build_generation_prompt for VL models."
)
- # Skip tool_call and unparsed_tool_call parts - handled via tool_calls field
result["content"] = "".join(parts)
# Handle tool_calls (convert ToolCall objects to OpenAI format)
@@ -758,23 +956,38 @@ def build_generation_prompt(
chunks: list[tinker.types.ModelInputChunk] = []
if self._bos_tokens:
chunks.append(tinker.types.EncodedTextChunk(tokens=self._bos_tokens))
+
+ last_user_idx = max(
+ (idx for idx, message in enumerate(messages) if message["role"] == "user"),
+ default=-1,
+ )
+
for idx, message in enumerate(messages):
ctx = RenderContext(
idx=idx,
is_last=(idx == len(messages) - 1),
prev_message=messages[idx - 1] if idx > 0 else None,
+ last_user_index=last_user_idx,
)
rendered_message = self.render_message(message, ctx)
header_chunk = rendered_message.header
output_chunks = rendered_message.output
if header_chunk:
chunks.append(header_chunk)
- chunks.extend([x for x in output_chunks if x])
+ # Filter out empty EncodedTextChunks, which cause 400 errors in model requests
+ chunks.extend(
+ [
+ x
+ for x in output_chunks
+ if not isinstance(x, tinker.EncodedTextChunk) or x.tokens
+ ]
+ )
suffix_ctx = RenderContext(
idx=len(messages),
is_last=True,
prev_message=messages[-1] if messages else None,
+ last_user_index=last_user_idx,
)
suffix_tokens = self._get_generation_suffix(role, suffix_ctx)
if suffix_tokens:
@@ -788,6 +1001,30 @@ def build_generation_prompt(
)
return tinker.ModelInput(chunks=chunks)
+ def build_supervised_examples(
+ self,
+ messages: list[Message],
+ train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_TURN,
+ ) -> list[tuple[tinker.ModelInput, torch.Tensor]]:
+ """
+ Build tokens and per-token weights for supervised fine-tuning.
+ It returns a list of (model_input, weights) tuples rather than a single example,
+ because renderers that do not satisfy the extension property may need to split
+ a conversation into several supervised examples.
+
+ This default implementation concatenates rendered messages in order, which assumes the renderer satisfies the extension property.
+ Override this method if your renderer does not satisfy the extension property.
+ """
+
+ if self.has_extension_property:
+ return [
+ self.build_supervised_example(messages, train_on_what=train_on_what)
+ ]
+ else:
+ # TODO: Add a default implementation that calls `build_supervised_example` for each message and merges examples with shared prefixes.
+ raise NotImplementedError(
+ "build_supervised_examples has not been implemented for this renderer."
+ )
+
def build_supervised_example(
self,
messages: list[Message],
@@ -810,6 +1047,7 @@ def build_supervised_example(
messages: A list of messages to render.
train_on_what: Controls which tokens receive non-zero training weight:
- LAST_ASSISTANT_MESSAGE: Only the last assistant message
+ - LAST_ASSISTANT_TURN: The last assistant message after the last user message
- ALL_ASSISTANT_MESSAGES: All assistant messages
- ALL_MESSAGES: All messages (but not headers)
- ALL_TOKENS: Everything including headers
@@ -848,6 +1086,11 @@ def build_supervised_example(
(tinker.types.EncodedTextChunk(tokens=self._bos_tokens), 0.0)
)
+ last_user_idx = max(
+ (idx for idx, message in enumerate(messages) if message["role"] == "user"),
+ default=-1,
+ )
+
for idx, message in enumerate(messages):
if train_on_what == TrainOnWhat.CUSTOMIZED:
assert "trainable" in message, (
@@ -861,12 +1104,14 @@ def build_supervised_example(
is_last_message = idx == len(messages) - 1
is_assistant = message["role"] == "assistant"
is_user_or_system = message["role"] in ["user", "system"]
+ is_after_last_user = last_user_idx == -1 or idx > last_user_idx
# only apply weight to header if train_on_what is ALL_TOKENS
ctx = RenderContext(
idx=idx,
is_last=is_last_message,
prev_message=messages[idx - 1] if idx > 0 else None,
+ last_user_index=last_user_idx,
)
rendered_message = self.render_message(message, ctx)
header_part = rendered_message.header
@@ -880,6 +1125,8 @@ def build_supervised_example(
match train_on_what:
case TrainOnWhat.LAST_ASSISTANT_MESSAGE:
output_has_weight = is_last_message and is_assistant
+ case TrainOnWhat.LAST_ASSISTANT_TURN:
+ output_has_weight = is_assistant and is_after_last_user
case TrainOnWhat.ALL_ASSISTANT_MESSAGES:
output_has_weight = is_assistant
case TrainOnWhat.ALL_MESSAGES:
@@ -958,9 +1205,7 @@ def parse_response_for_stop_token(
)
-# ============================================================================
# Image processing utilities (used by VL renderers)
-# ============================================================================
class ImageProcessorProtocol(Protocol):
@@ -972,6 +1217,9 @@ def get_number_of_image_patches(
) -> int:
raise NotImplementedError()
+ def get_resize_config(self, image_data: dict[str, Any]) -> dict[str, Any]:
+ raise NotImplementedError()
+
def image_to_chunk(
image_or_str: Image.Image | str, image_processor: ImageProcessorProtocol
@@ -1003,11 +1251,23 @@ def image_to_chunk(
pil_image.save(img_byte_arr, format="JPEG")
image_data = img_byte_arr.getvalue()
- width, height = pil_image.size
- num_image_tokens = (
- image_processor.get_number_of_image_patches(height, width, images_kwargs={})
- // image_processor.merge_size**2
- )
+ # Get the number of expected tokens for the image. The API for this is not consistent across
+ # image processors (Qwen3-VL supports get_number_of_image_patches; Kimi K2.5 doesn't, but has get_resize_config).
+ if hasattr(image_processor, "get_number_of_image_patches"):
+ width, height = pil_image.size
+ num_image_tokens = (
+ image_processor.get_number_of_image_patches(height, width, images_kwargs={})
+ // image_processor.merge_size**2
+ )
+ elif hasattr(image_processor, "get_resize_config"):
+ config = image_processor.get_resize_config(
+ {"type": "image", "image": pil_image}
+ )
+ num_image_tokens = config["num_tokens"]
+ else:
+ raise ValueError(
+ f"Don't know how to get the number of image tokens for image processor: {image_processor}"
+ )
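+
+ # For scale: with merge_size=2, an image yielding 1024 patches maps to
+ # 1024 // 2**2 = 256 image tokens (numbers here are illustrative).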
return tinker.types.ImageChunk(
data=image_data,
diff --git a/src/art/tinker/cookbook_v/renderers/deepseek_v3.py b/src/art/tinker/cookbook_v/renderers/deepseek_v3.py
index a16b4843..148a6829 100644
--- a/src/art/tinker/cookbook_v/renderers/deepseek_v3.py
+++ b/src/art/tinker/cookbook_v/renderers/deepseek_v3.py
@@ -127,7 +127,6 @@ def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessag
rendered_parts.append(f"{p['thinking']}")
elif p["type"] == "text":
rendered_parts.append(p["text"])
- # ToolCallPart handled via message's tool_calls field
output_content = "".join(rendered_parts)
else:
# String content - pass through as-is.
diff --git a/src/art/tinker/cookbook_v/renderers/kimi_k2.py b/src/art/tinker/cookbook_v/renderers/kimi_k2.py
index dd087a2e..2ac5ab42 100644
--- a/src/art/tinker/cookbook_v/renderers/kimi_k2.py
+++ b/src/art/tinker/cookbook_v/renderers/kimi_k2.py
@@ -1,22 +1,32 @@
"""Renderer for Moonshot AI's Kimi K2 models."""
+from dataclasses import dataclass, field
import json
import re
+from typing import Iterator
import warnings
import tinker
import torch
+from ..tokenizer_utils import Tokenizer
from .base import (
+ ContentPart,
Message,
+ MessageDelta,
RenderContext,
RenderedMessage,
Renderer,
Role,
+ StreamingMessageHeader,
+ StreamingTextDelta,
+ StreamingThinkingDelta,
+ TextPart,
ToolCall,
ToolSpec,
TrainOnWhat,
UnparsedToolCall,
+ Utf8TokenDecoder,
ensure_list,
ensure_text,
parse_response_for_stop_token,
@@ -79,6 +89,269 @@ def _parse_tool_calls_section(
return tool_calls, unparsed_tool_calls
+# =============================================================================
+# Streaming Parser
+# =============================================================================
+
+# Tags we need to detect during streaming
+_THINK_OPEN_TAG = ""
+_THINK_CLOSE_TAG = ""
+
+
+def _longest_matching_suffix_prefix(text: str, tag: str) -> int:
+ """Find longest suffix of text that matches a prefix of tag.
+
+ This is used during streaming to determine how many characters at the end
+ of accumulated text might be the beginning of a tag, and thus shouldn't
+ be emitted yet.
+
+ Args:
+ text: The accumulated text to check.
+ tag: The tag we're looking for (e.g., "<think>").
+
+ Returns:
+ Length of the longest suffix of text that matches a prefix of tag.
+
+ Examples:
+ >>> _longest_matching_suffix_prefix("hello", "")
+ 0 # no suffix matches any prefix
+ >>> _longest_matching_suffix_prefix("hello<", "")
+ 1 # "<" matches prefix "<"
+ >>> _longest_matching_suffix_prefix("hello| ")
+ 3 # " | >> _longest_matching_suffix_prefix("hello")
+ 0 # ""
+ """
+ max_check = min(
+ len(text), len(tag) - 1
+ ) # -1 because full tag would be found, not buffered
+ for length in range(max_check, 0, -1):
+ if text.endswith(tag[:length]):
+ return length
+ return 0
+
+
+@dataclass
+class KimiK2StreamingParser:
+ """Stateful streaming parser for Kimi K2 model output.
+
+ Parses tokens incrementally, yielding deltas as content becomes available.
+ Handles UTF-8 token boundaries and <think>...</think> block transitions.
+
+ Usage:
+ parser = KimiK2StreamingParser(tokenizer, end_token)
+ for token in response_tokens:
+ for delta in parser.feed(token):
+ # Handle delta (StreamingMessageHeader, StreamingTextDelta, etc.)
+ for delta in parser.finish():
+ # Handle final deltas including complete Message
+ """
+
+ tokenizer: Tokenizer
+ end_message_token: int
+
+ # Internal state
+ _utf8_decoder: Utf8TokenDecoder = field(init=False)
+ _accumulated_text: str = field(init=False, default="")
+ _header_emitted: bool = field(init=False, default=False)
+ _in_thinking: bool = field(init=False, default=False)
+ _content_index: int = field(init=False, default=0)
+ _last_emitted_pos: int = field(init=False, default=0)
+ _finished: bool = field(init=False, default=False)
+ _all_tokens: list[int] = field(init=False, default_factory=list)
+
+ def __post_init__(self) -> None:
+ self._utf8_decoder = Utf8TokenDecoder(self.tokenizer)
+ self._accumulated_text = ""
+ self._header_emitted = False
+ self._in_thinking = False
+ self._content_index = 0
+ self._last_emitted_pos = 0
+ self._finished = False
+ self._all_tokens = []
+
+ def feed(self, token: int) -> Iterator[MessageDelta]:
+ """Feed a single token and yield any resulting deltas.
+
+ Args:
+ token: A single token ID from the model output.
+
+ Yields:
+ MessageDelta objects as content becomes available.
+ """
+ if self._finished:
+ return
+
+ self._all_tokens.append(token)
+
+ # Check for end token
+ if token == self.end_message_token:
+ self._finished = True
+ return
+
+ # Try to decode the token
+ decoded = self._utf8_decoder.decode([token])
+ if decoded is None:
+ # Token was buffered (incomplete UTF-8 sequence)
+ return
+
+ self._accumulated_text += decoded
+
+ # Emit header on first content
+ if not self._header_emitted:
+ self._header_emitted = True
+ yield StreamingMessageHeader(role="assistant")
+
+ # Process the new text for deltas
+ yield from self._emit_deltas()
+
+ def _emit_deltas(self) -> Iterator[MessageDelta]:
+ """Emit deltas for any new content since last emission."""
+ text = self._accumulated_text
+ pos = self._last_emitted_pos
+
+ while pos < len(text):
+ if not self._in_thinking:
+ # Look for <think> tag
+ think_start = text.find(_THINK_OPEN_TAG, pos)
+ if think_start == -1:
+ # No <think> tag found - emit text up to a safe point.
+ # Keep any trailing chars that could be the start of "<think>".
+ suffix_from_pos = text[pos:]
+ keep = _longest_matching_suffix_prefix(
+ suffix_from_pos, _THINK_OPEN_TAG
+ )
+ safe_end = len(text) - keep
+ if safe_end > pos:
+ new_text = text[pos:safe_end]
+ if new_text:
+ yield StreamingTextDelta(
+ text=new_text, content_index=self._content_index
+ )
+ self._last_emitted_pos = safe_end
+ break
+ elif think_start > pos:
+ # Emit text before <think>
+ new_text = text[pos:think_start]
+ if new_text:
+ yield StreamingTextDelta(
+ text=new_text, content_index=self._content_index
+ )
+ pos = think_start
+
+ if text[pos:].startswith(_THINK_OPEN_TAG):
+ # Enter thinking mode
+ self._in_thinking = True
+ self._content_index += 1
+ pos += len(_THINK_OPEN_TAG)
+ self._last_emitted_pos = pos
+ else:
+ # In thinking mode - look for </think>
+ think_end = text.find(_THINK_CLOSE_TAG, pos)
+ if think_end == -1:
+ # No </think> found - emit thinking up to a safe point.
+ # Keep any trailing chars that could be the start of "</think>".
+ suffix_from_pos = text[pos:]
+ keep = _longest_matching_suffix_prefix(
+ suffix_from_pos, _THINK_CLOSE_TAG
+ )
+ safe_end = len(text) - keep
+ if safe_end > pos:
+ new_thinking = text[pos:safe_end]
+ if new_thinking:
+ yield StreamingThinkingDelta(
+ thinking=new_thinking, content_index=self._content_index
+ )
+ self._last_emitted_pos = safe_end
+ break
+ else:
+ # Emit thinking before </think>
+ new_thinking = text[pos:think_end]
+ if new_thinking:
+ yield StreamingThinkingDelta(
+ thinking=new_thinking, content_index=self._content_index
+ )
+ # Exit thinking mode
+ self._in_thinking = False
+ self._content_index += 1
+ pos = think_end + len(_THINK_CLOSE_TAG)
+ self._last_emitted_pos = pos
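+
+ # For example, a stream decoding to "<think>hi</think>ok" produces a message
+ # header, thinking deltas totaling "hi" (content_index=1), and text deltas
+ # totaling "ok" (content_index=2). Partial tags such as "</th" are held back
+ # until they either complete into a full tag or turn out to be plain text.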
+
+ def finish(self) -> Iterator[MessageDelta]:
+ """Finish parsing and yield any remaining content plus final Message.
+
+ Call this after all tokens have been fed (either naturally when
+ end_message_token is seen, or when the stream ends).
+
+ Yields:
+ Any remaining deltas and the complete Message.
+ """
+ # Flush any buffered UTF-8 tokens
+ remaining = self._utf8_decoder.flush()
+ if remaining:
+ self._accumulated_text += remaining
+
+ # Emit header if we haven't yet (empty response edge case)
+ if not self._header_emitted:
+ self._header_emitted = True
+ yield StreamingMessageHeader(role="assistant")
+
+ # Emit any remaining content
+ text = self._accumulated_text
+ pos = self._last_emitted_pos
+
+ if pos < len(text):
+ remaining_text = text[pos:]
+ if self._in_thinking:
+ # Unclosed thinking block - emit as thinking
+ if remaining_text:
+ yield StreamingThinkingDelta(
+ thinking=remaining_text, content_index=self._content_index
+ )
+ else:
+ if remaining_text:
+ yield StreamingTextDelta(
+ text=remaining_text, content_index=self._content_index
+ )
+
+ # Build and yield the final complete Message
+ # Use the batch parser for consistency
+ message, _success = parse_response_for_stop_token(
+ self._all_tokens, self.tokenizer, self.end_message_token
+ )
+
+ content = message.get("content", "")
+ if isinstance(content, str):
+ # Handle tool calls if present
+ text_content, tool_section = _split_tool_calls_section(content)
+ if tool_section is not None:
+ tool_calls, unparsed_tool_calls = _parse_tool_calls_section(
+ tool_section
+ )
+ if tool_calls:
+ message["tool_calls"] = tool_calls
+ if unparsed_tool_calls:
+ message["unparsed_tool_calls"] = unparsed_tool_calls
+
+ content_parts = parse_think_blocks(text_content)
+ message["content"] = (
+ content_parts if content_parts is not None else text_content
+ )
+
+ yield message
+
+ def reset(self) -> None:
+ """Reset parser state for reuse."""
+ self._utf8_decoder.reset()
+ self._accumulated_text = ""
+ self._header_emitted = False
+ self._in_thinking = False
+ self._content_index = 0
+ self._last_emitted_pos = 0
+ self._finished = False
+ self._all_tokens = []
+
+
class KimiK2Renderer(Renderer):
"""
Format for moonshotai/Kimi-K2-Thinking:
@@ -86,8 +359,8 @@ class KimiK2Renderer(Renderer):
<|im_user|>user<|im_middle|>What can you help me with?<|im_end|>
<|im_assistant|>assistant<|im_middle|><think>reasoning</think>I can help you with...<|im_end|>
- Historical assistant messages use empty <think></think> blocks, while the final assistant
- response preserves reasoning_content in the thinking block.
+ Historical assistant messages use empty <think></think> blocks, while assistant messages after
+ the last non-tool-call assistant message preserve reasoning_content in the thinking block.
Note: Per the HuggingFace chat template, the default system message is automatically
prepended if no system message is provided. This ensures train-eval consistency when
@@ -96,6 +369,10 @@ class KimiK2Renderer(Renderer):
DEFAULT_SYSTEM_PROMPT = "You are Kimi, an AI assistant created by Moonshot AI."
+ def __init__(self, tokenizer: Tokenizer, strip_thinking_from_history: bool = True):
+ super().__init__(tokenizer)
+ self.strip_thinking_from_history = strip_thinking_from_history
+
def _ensure_system_message(self, messages: list[Message]) -> list[Message]:
"""Ensure a default system message is present if none exists.
@@ -172,17 +449,22 @@ def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessag
header_str = f"<|im_system|>{role}<|im_middle|>"
# Build output content
- output_str = ""
+ content = message["content"]
+ output: list[tinker.ModelInputChunk] = []
if role == "assistant":
+ output_str = ""
# Extract thinking and text from content list
- parts = ensure_list(message["content"])
+ parts = ensure_list(content)
thinking_content = "".join(
p["thinking"] for p in parts if p["type"] == "thinking"
)
text_content = "".join(p["text"] for p in parts if p["type"] == "text")
- # For the last assistant message (is_last=True), preserve thinking; otherwise use empty think block
- if ctx.is_last and thinking_content:
+ # Preserve thinking for the last assistant message, or for all messages
+ # when strip_thinking_from_history is False.
+ if (
+ ctx.is_last or not self.strip_thinking_from_history
+ ) and thinking_content:
output_str = f"{thinking_content}"
else:
output_str = ""
@@ -198,17 +480,38 @@ def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessag
args = tool_call.function.arguments
output_str += f"<|tool_call_begin|>{tool_id}<|tool_call_argument_begin|>{args}<|tool_call_end|>"
output_str += "<|tool_calls_section_end|>"
+ output_str += "<|im_end|>"
+ output.append(
+ tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(output_str))
+ )
+ elif isinstance(content, str) or (
+ len(content) == 1 and content[0]["type"] == "text"
+ ):
+ # Single-part/text content
+ output_str = ensure_text(content) + "<|im_end|>"
+ output.append(
+ tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(output_str))
+ )
else:
- output_str = ensure_text(message["content"])
-
- output_str += "<|im_end|>"
+ # Multi-part content (e.g. text + image(s))
+ assert isinstance(content, list), (
+ f"Expected list of content parts, got {type(content)}"
+ )
+ output = self._encode_multipart_content(
+ content + [TextPart(type="text", text="<|im_end|>")]
+ )
header = tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(header_str))
- output: list[tinker.ModelInputChunk] = [
- tinker.types.EncodedTextChunk(tokens=self.tokenizer.encode(output_str))
- ]
+
return RenderedMessage(header=header, output=output)
+ def _encode_multipart_content(
+ self, content: list[ContentPart]
+ ) -> list[tinker.ModelInputChunk]:
+ raise NotImplementedError(
+ "Multipart/Image content encoding is not supported for Kimi K2 renderer"
+ )
+
def build_generation_prompt(
self,
messages: list[Message],
@@ -218,11 +521,26 @@ def build_generation_prompt(
messages = self._ensure_system_message(messages)
chunks: list[tinker.types.ModelInputChunk] = []
+ # Find last assistant message without tool calls (matches hf template behavior).
+ last_assistant_idx = -1
+ for idx in range(len(messages) - 1, -1, -1):
+ if messages[idx]["role"] == "assistant" and not messages[idx].get(
+ "tool_calls"
+ ):
+ last_assistant_idx = idx
+ break
+
for idx, message in enumerate(messages):
- # For generation prompt, no message is "last assistant" since we're generating new response
+ is_assistant = message["role"] == "assistant"
+ is_last_assistant = is_assistant and (
+ last_assistant_idx == -1 or idx > last_assistant_idx
+ )
+
+ # We cannot simply set is_last=False since we might be generating a new assistant message following a tool response,
+ # and we need to preserve the thinking that leads to the tool call.
ctx = RenderContext(
idx=idx,
- is_last=False,
+ is_last=is_last_assistant,
prev_message=messages[idx - 1] if idx > 0 else None,
)
rendered_message = self.render_message(message, ctx)
@@ -243,6 +561,57 @@ def build_generation_prompt(
)
return tinker.ModelInput(chunks=chunks)
+ def build_supervised_examples(
+ self,
+ messages: list[Message],
+ train_on_what: TrainOnWhat = TrainOnWhat.LAST_ASSISTANT_TURN,
+ ) -> list[tuple[tinker.ModelInput, torch.Tensor]]:
+ """
+ Build tokens and per-token weights for supervised fine-tuning.
+
+ The Kimi K2 renderer does not satisfy the extension property, so this method
+ can return multiple examples when training on assistant messages across
+ multiple turns of user-assistant conversation.
+ """
+
+ if (
+ train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE
+ or train_on_what == TrainOnWhat.LAST_ASSISTANT_TURN
+ ):
+ return [
+ self.build_supervised_example(messages, train_on_what=train_on_what)
+ ]
+
+ # split the messages into turns by user messages
+ user_message_idxs = [
+ idx for idx, message in enumerate(messages) if message["role"] == "user"
+ ]
+
+ supervised_examples: list[tuple[tinker.ModelInput, torch.Tensor]] = []
+
+ if train_on_what != TrainOnWhat.ALL_ASSISTANT_MESSAGES:
+ warnings.warn(
+ "WARNING: Using train_on_what=ALL_MESSAGES/ALL_TOKENS/ALL_USER_AND_SYSTEM_MESSAGES/CUSTOMIZED with a renderer that "
+ "does not satisfy the extension property (has_extension_property=False). "
+ "The behavior is we apply the same `train_on_what` to all turns. This may not be the desired behavior.",
+ UserWarning,
+ stacklevel=3,
+ )
+
+ # We separate the turns by user messages. The first turn is the messages before the second user message.
+ for user_message_idx in [*user_message_idxs[1:], len(messages)]:
+ current_messages = messages[:user_message_idx]
+ if train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES:
+ supervised_examples.append(
+ self.build_supervised_example(
+ current_messages, train_on_what=TrainOnWhat.LAST_ASSISTANT_TURN
+ )
+ )
+ else:
+ supervised_examples.append(
+ self.build_supervised_example(
+ current_messages, train_on_what=train_on_what
+ )
+ )
+
+ return supervised_examples
+
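+ # For example, [system, user1, assistant1, user2, assistant2] with
+ # ALL_ASSISTANT_MESSAGES yields two examples: one over the first three messages
+ # and one over all five, each weighting only its final assistant turn.
+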
def build_supervised_example(
self,
messages: list[Message],
@@ -254,12 +623,14 @@ def build_supervised_example(
"""
messages = self._ensure_system_message(messages)
- # Find last non-tool-call assistant message index
+ # The Kimi K2 HF template preserves thinking only for assistant messages after the last
+ # non-tool-call assistant message, and we follow the same rule in general. However, we
+ # intentionally skip the last message when searching (which differs from the HF template):
+ # for a complete conversation, this preserves the thinking of the entire final round of
+ # user-assistant interaction (which may include multiple assistant messages and tool calls),
+ # so the trajectory can be used for SFT without losing that thinking content.
last_assistant_idx = -1
- for idx in range(len(messages) - 1, -1, -1):
- if (
- messages[idx]["role"] == "assistant"
- and "tool_calls" not in messages[idx]
+ for idx in range(len(messages) - 2, -1, -1):
+ if messages[idx]["role"] == "assistant" and not messages[idx].get(
+ "tool_calls"
):
last_assistant_idx = idx
break
@@ -278,17 +649,21 @@ def build_supervised_example(
"When using non-CUSTOMIZED train_on_what, each message must not have a trainable field"
)
- is_last_message = idx == len(messages) - 1
is_assistant = message["role"] == "assistant"
+ is_last_message = idx == len(messages) - 1
is_user_or_system = message["role"] in ["user", "system"]
# For Kimi K2, preserve thinking only for the suffix after the last non-tool-call assistant.
- is_last_assistant = (
- is_assistant and last_assistant_idx != -1 and idx >= last_assistant_idx
+ # If no such assistant exists, the suffix is the entire message list.
+ # Preserve thinking only for assistants after the last non-tool-call assistant.
+ is_last_assistant_turn = is_assistant and (
+ last_assistant_idx == -1 or idx > last_assistant_idx
)
+
+ is_last_assistant = is_assistant and is_last_message
ctx = RenderContext(
idx=idx,
- is_last=is_last_assistant,
+ is_last=is_last_assistant_turn,
prev_message=messages[idx - 1] if idx > 0 else None,
)
rendered_message = self.render_message(message, ctx)
@@ -300,9 +675,12 @@ def build_supervised_example(
if header_part:
model_input_chunks_weights += [(header_part, header_weight)]
+ # LAST_ASSISTANT_TURN treats every assistant message in the final round of assistant-tool interactions as trainable.
match train_on_what:
case TrainOnWhat.LAST_ASSISTANT_MESSAGE:
- output_has_weight = is_last_message and is_assistant
+ output_has_weight = is_last_assistant
+ case TrainOnWhat.LAST_ASSISTANT_TURN:
+ output_has_weight = is_last_assistant_turn
case TrainOnWhat.ALL_ASSISTANT_MESSAGES:
output_has_weight = is_assistant
case TrainOnWhat.ALL_MESSAGES:
@@ -367,6 +745,43 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
return assistant_message, True
+ def parse_response_streaming(self, response: list[int]) -> Iterator[MessageDelta]:
+ """Parse response tokens with streaming, yielding incremental deltas.
+
+ This method enables real-time display of model output by yielding
+ partial content as it becomes available, rather than waiting for
+ the complete response.
+
+ Args:
+ response: Token IDs from the model.
+
+ Yields:
+ StreamingMessageHeader: Once at the start of the message.
+ StreamingTextDelta: Incremental text content.
+ StreamingThinkingDelta: Incremental thinking content.
+ Message: The complete parsed message at the end.
+
+ Example:
+ for delta in renderer.parse_response_streaming(tokens):
+ if isinstance(delta, StreamingMessageHeader):
+ print(f"New message from {delta.role}")
+ elif isinstance(delta, StreamingThinkingDelta):
+ print(f"[thinking] {delta.thinking}", end="")
+ elif isinstance(delta, StreamingTextDelta):
+ print(delta.text, end="")
+ elif isinstance(delta, Message):
+ print(f"\\nComplete: {delta}")
+ """
+ parser = KimiK2StreamingParser(
+ tokenizer=self.tokenizer,
+ end_message_token=self._end_message_token,
+ )
+
+ for token in response:
+ yield from parser.feed(token)
+
+ yield from parser.finish()
+
def to_openai_message(self, message: Message) -> dict:
"""Convert a Message to OpenAI API format with reasoning_content for thinking.
diff --git a/src/art/tinker/cookbook_v/renderers/kimi_k25.py b/src/art/tinker/cookbook_v/renderers/kimi_k25.py
new file mode 100644
index 00000000..9fbce3de
--- /dev/null
+++ b/src/art/tinker/cookbook_v/renderers/kimi_k25.py
@@ -0,0 +1,158 @@
+"""Renderer for Moonshot AI's Kimi K2.5 models."""
+
+from typing import cast
+
+import tinker
+
+from ..image_processing_utils import ImageProcessor
+from ..tokenizer_utils import Tokenizer
+from .base import (
+ ContentPart,
+ ImageProcessorProtocol,
+ Message,
+ Role,
+ ToolSpec,
+ image_to_chunk,
+)
+from .kimi_k2 import KimiK2Renderer
+from .kimi_k2_5_tool_declaration_ts import encode_tools_to_typescript_style
+
+
+class KimiK25Renderer(KimiK2Renderer):
+ """
+ Renderer for Kimi K2.5 with thinking enabled (default).
+
+ Key differences from KimiK2Renderer:
+ 1. Generation prompt prefill: Appends `<think>` (open tag) to enable thinking mode
+ 2. Tool declarations: Uses TypeScript-style format instead of JSON
+
+ Format:
+ <|im_system|>system<|im_middle|>You are Kimi...<|im_end|>
+ <|im_user|>user<|im_middle|>Hello<|im_end|>
+ <|im_assistant|>assistant<|im_middle|><think>
+
+ Historical assistant messages use empty <think></think> blocks (inherited from K2),
+ while the generation prompt adds an open <think> tag to enable thinking.
+ """
+
+ image_processor: ImageProcessor | None
+
+ def __init__(
+ self,
+ tokenizer: Tokenizer,
+ image_processor: ImageProcessor | None = None,
+ strip_thinking_from_history: bool = True,
+ ):
+ super().__init__(
+ tokenizer, strip_thinking_from_history=strip_thinking_from_history
+ )
+ self.image_processor = image_processor
+
+ def _encode_multipart_content(
+ self, content: list[ContentPart]
+ ) -> list[tinker.ModelInputChunk]:
+ chunks = []
+ for part in content:
+ if part["type"] == "text":
+ chunks.append(
+ tinker.types.EncodedTextChunk(
+ tokens=self.tokenizer.encode(part["text"])
+ )
+ )
+ elif part["type"] == "image":
+ assert self.image_processor is not None, (
+ "KimiK25Renderer must be initialized with an image processor in order to support image content parts"
+ )
+ chunks.append(
+ tinker.types.EncodedTextChunk(
+ tokens=self.tokenizer.encode(self._image_prefix)
+ )
+ )
+ chunks.append(
+ image_to_chunk(
+ part["image"],
+ cast(ImageProcessorProtocol, self.image_processor),
+ )
+ )
+ chunks.append(
+ tinker.types.EncodedTextChunk(
+ tokens=self.tokenizer.encode(self._image_suffix)
+ )
+ )
+ else:
+ raise ValueError(f"Unsupported content type: {part['type']}")
+ return chunks
+
+ @property
+ def _image_prefix(self) -> str:
+ return "<|media_begin|>image<|media_content|>"
+
+ @property
+ def _image_suffix(self) -> str:
+ return "<|media_end|>\n"
+
+ def build_generation_prompt(
+ self,
+ messages: list[Message],
+ role: Role = "assistant",
+ prefill: str | None = None,
+ ) -> tinker.ModelInput:
+ """Build generation prompt with prefill for thinking mode."""
+ # If no prefill specified, use to enable thinking
+ if prefill is None:
+ prefill = ""
+ return super().build_generation_prompt(messages, role=role, prefill=prefill)
+
+ def create_conversation_prefix_with_tools(
+ self, tools: list[ToolSpec], system_prompt: str = ""
+ ) -> list[Message]:
+ """Create system messages with TypeScript-style tool specifications.
+
+ Per the HuggingFace chat template, Kimi K2.5 uses TypeScript-style tool
+ declarations instead of JSON format. The tool_declare message comes BEFORE
+ the regular system message.
+
+ Reference: kimi-k2.5-hf-tokenizer/chat_template.jinja
+ """
+ messages: list[Message] = []
+
+ # Tool declaration message comes first (per HF chat template)
+ if tools:
+ tools_payload = [{"type": "function", "function": tool} for tool in tools]
+ tools_ts_str = encode_tools_to_typescript_style(tools_payload)
+ messages.append(Message(role="tool_declare", content=tools_ts_str))
+
+ # Regular system message second (use default if none provided)
+ actual_system_prompt = (
+ system_prompt if system_prompt else self.DEFAULT_SYSTEM_PROMPT
+ )
+ messages.append(Message(role="system", content=actual_system_prompt))
+
+ return messages
+
+
+class KimiK25DisableThinkingRenderer(KimiK25Renderer):
+ """
+ Renderer for Kimi K2.5 with thinking disabled.
+
+ Uses `<think></think>` prefill instead of `<think>` to disable thinking mode.
+
+ Format:
+ <|im_system|>system<|im_middle|>You are Kimi...<|im_end|>
+ <|im_user|>user<|im_middle|>Hello<|im_end|>
+ <|im_assistant|>assistant<|im_middle|><think></think>
+ """
+
+ def build_generation_prompt(
+ self,
+ messages: list[Message],
+ role: Role = "assistant",
+ prefill: str | None = None,
+ ) -> tinker.ModelInput:
+ """Build generation prompt with prefill to disable thinking."""
+ # If no prefill is specified, use "<think></think>" to disable thinking
+ if prefill is None:
+ prefill = "<think></think>"
+ return super(KimiK25Renderer, self).build_generation_prompt(
+ messages, role=role, prefill=prefill
+ )
diff --git a/src/art/tinker/cookbook_v/renderers/kimi_k2_5_tool_declaration_ts.py b/src/art/tinker/cookbook_v/renderers/kimi_k2_5_tool_declaration_ts.py
new file mode 100644
index 00000000..7201bc5f
--- /dev/null
+++ b/src/art/tinker/cookbook_v/renderers/kimi_k2_5_tool_declaration_ts.py
@@ -0,0 +1,501 @@
+"""
+Encode structured tool declarations to TypeScript-style strings.
+
+Copied from kimi-k2.5-hf-tokenizer/tool_declaration_ts.py for Kimi K2.5 support.
+"""
+
+from collections.abc import Sequence
+import dataclasses
+import json
+import logging
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+_TS_INDENT = " "
+_TS_FIELD_DELIMITER = ",\n"
+
+
+class _SchemaRegistry:
+ """Registry for schema definitions to handle $ref resolution"""
+
+ def __init__(self):
+ self.definitions = {}
+ self.has_self_ref = False
+
+ def register_definitions(self, defs: dict[str, Any]):
+ """Register schema definitions from $defs section"""
+ if not defs:
+ return
+ for def_name, def_schema in defs.items():
+ self.definitions[def_name] = def_schema
+
+ def resolve_ref(self, ref: str) -> dict[str, Any]:
+ """Resolve a reference to its schema definition"""
+ if ref == "#":
+ self.has_self_ref = True
+ return {"$self_ref": True}
+ elif ref.startswith("#/$defs/"):
+ def_name = ref.split("/")[-1]
+ if def_name not in self.definitions:
+ raise ValueError(f"Reference not found: {ref}")
+ return self.definitions[def_name]
+ else:
+ raise ValueError(f"Unsupported reference format: {ref}")
+
+
+def _format_description(description: str, indent: str = "") -> str:
+ return "\n".join(
+ [f"{indent}// {line}" if line else "" for line in description.split("\n")]
+ )
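+
+
+# For example, _format_description("line one\n\nline two", " ") returns
+# " // line one\n\n // line two": blank lines stay blank, others get a "// " prefix.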
+
+
+class _BaseType:
+ description: str
+ constraints: dict[str, Any]
+
+ def __init__(
+ self,
+ extra_props: dict[str, Any],
+ *,
+ allowed_constraint_keys: Sequence[str] = (),
+ ):
+ self.description = extra_props.get("description", "")
+ self.constraints = {
+ k: v for k, v in extra_props.items() if k in allowed_constraint_keys
+ }
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ raise NotImplementedError
+
+ def format_docstring(self, indent: str) -> str:
+ lines = []
+ if self.description:
+ lines.append(_format_description(self.description, indent))
+ if self.constraints:
+ constraints_str = ", ".join(
+ f"{k}: {v}"
+ for k, v in sorted(self.constraints.items(), key=lambda kv: kv[0])
+ )
+ lines.append(f"{indent}// {constraints_str}")
+
+ return "".join(x + "\n" for x in lines)
+
+
+class _ParameterTypeScalar(_BaseType):
+ type: str
+
+ def __init__(self, type: str, extra_props: dict[str, Any] | None = None):
+ self.type = type
+
+ allowed_constraint_keys: list[str] = []
+ if self.type == "string":
+ allowed_constraint_keys = ["maxLength", "minLength", "pattern"]
+ elif self.type in ("number", "integer"):
+ allowed_constraint_keys = ["maximum", "minimum"]
+
+ super().__init__(
+ extra_props or {}, allowed_constraint_keys=allowed_constraint_keys
+ )
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ # Map integer to number in TypeScript
+ if self.type == "integer":
+ return "number"
+ return self.type
+
+
+class _ParameterTypeObject(_BaseType):
+ properties: list["_Parameter"]
+ additional_properties: Any | None = None
+
+ def __init__(
+ self,
+ json_schema_object: dict[str, Any],
+ registry: _SchemaRegistry | None = None,
+ ):
+ super().__init__(json_schema_object)
+
+ self.properties = []
+ self.additional_properties = None
+
+ if not json_schema_object:
+ return
+
+ if "$defs" in json_schema_object and registry:
+ registry.register_definitions(json_schema_object["$defs"])
+
+ self.additional_properties = json_schema_object.get("additionalProperties")
+ if isinstance(self.additional_properties, dict):
+ self.additional_properties = _parse_parameter_type(
+ self.additional_properties, registry
+ )
+
+ if "properties" not in json_schema_object:
+ return
+
+ required_parameters = json_schema_object.get("required", [])
+ optional_parameters = set(json_schema_object["properties"].keys()) - set(
+ required_parameters
+ )
+
+ self.properties = [
+ _Parameter(
+ name=name,
+ type=_parse_parameter_type(prop, registry),
+ optional=name in optional_parameters,
+ default=prop.get("default") if isinstance(prop, dict) else None,
+ )
+ for name, prop in json_schema_object["properties"].items()
+ ]
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ # Sort so that required parameters come first, optional ones after
+ parameters = [p for p in self.properties if not p.optional]
+ opt_params = [p for p in self.properties if p.optional]
+
+ parameters = sorted(parameters, key=lambda p: p.name)
+ parameters.extend(sorted(opt_params, key=lambda p: p.name))
+
+ param_strs = []
+ for p in parameters:
+ one = p.to_typescript_style(indent=indent + _TS_INDENT)
+ param_strs.append(one)
+
+ if self.additional_properties is not None:
+ ap_type_str = "any"
+ if self.additional_properties is True:
+ ap_type_str = "any"
+ elif self.additional_properties is False:
+ ap_type_str = "never"
+ elif isinstance(self.additional_properties, _ParameterType):
+ ap_type_str = self.additional_properties.to_typescript_style(
+ indent=indent + _TS_INDENT
+ )
+ else:
+ raise ValueError(
+ f"Unknown additionalProperties: {self.additional_properties}"
+ )
+ param_strs.append(f"{indent + _TS_INDENT}[k: string]: {ap_type_str}")
+
+ if not param_strs:
+ return "{}"
+
+ params_str = _TS_FIELD_DELIMITER.join(param_strs)
+ if params_str:
+ # add new line before and after
+ params_str = f"\n{params_str}\n"
+ # always wrap with object
+ return f"{{{params_str}{indent}}}"
+
+
+class _ParameterTypeArray(_BaseType):
+ item: "_ParameterType"
+
+ def __init__(
+ self,
+ json_schema_object: dict[str, Any],
+ registry: _SchemaRegistry | None = None,
+ ):
+ super().__init__(
+ json_schema_object, allowed_constraint_keys=("minItems", "maxItems")
+ )
+ if json_schema_object.get("items"):
+ self.item = _parse_parameter_type(json_schema_object["items"], registry)
+ else:
+ self.item = _ParameterTypeScalar(type="any")
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ item_docstring = self.item.format_docstring(indent + _TS_INDENT)
+ if item_docstring:
+ return (
+ "Array<\n"
+ + item_docstring
+ + indent
+ + _TS_INDENT
+ + self.item.to_typescript_style(indent=indent + _TS_INDENT)
+ + "\n"
+ + indent
+ + ">"
+ )
+ else:
+ return f"Array<{self.item.to_typescript_style(indent=indent)}>"
+
+
+class _ParameterTypeEnum(_BaseType):
+ # support scalar types only
+ enum: list[str | int | float | bool | None]
+
+ def __init__(self, json_schema_object: dict[str, Any]):
+ super().__init__(json_schema_object)
+ self.enum = json_schema_object["enum"]
+
+ # Validate enum values against declared type if present
+ if "type" in json_schema_object:
+ typ = json_schema_object["type"]
+ if isinstance(typ, list):
+ if len(typ) == 1:
+ typ = typ[0]
+ elif len(typ) == 2:
+ if "null" not in typ:
+ raise ValueError(f"Enum type {typ} is not supported")
+ else:
+ typ = typ[0] if typ[0] != "null" else typ[1]
+ else:
+ raise ValueError(f"Enum type {typ} is not supported")
+ for val in self.enum:
+ if val is None:
+ continue
+ if typ == "string" and not isinstance(val, str):
+ raise ValueError(f"Enum value {val} is not a string")
+ elif typ == "number" and not isinstance(val, (int, float)):
+ raise ValueError(f"Enum value {val} is not a number")
+ elif typ == "integer" and not isinstance(val, int):
+ raise ValueError(f"Enum value {val} is not an integer")
+ elif typ == "boolean" and not isinstance(val, bool):
+ raise ValueError(f"Enum value {val} is not a boolean")
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ return " | ".join(
+ [f'"{e}"' if isinstance(e, str) else str(e) for e in self.enum]
+ )
+
+
+class _ParameterTypeAnyOf(_BaseType):
+ types: list["_ParameterType"]
+
+ def __init__(
+ self,
+ json_schema_object: dict[str, Any],
+ registry: _SchemaRegistry | None = None,
+ ):
+ super().__init__(json_schema_object)
+ self.types = [
+ _parse_parameter_type(t, registry) for t in json_schema_object["anyOf"]
+ ]
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ return " | ".join([t.to_typescript_style(indent=indent) for t in self.types])
+
+
+class _ParameterTypeUnion(_BaseType):
+ types: list[str]
+
+ def __init__(self, json_schema_object: dict[str, Any]):
+ super().__init__(json_schema_object)
+
+ mapping = {
+ "string": "string",
+ "number": "number",
+ "integer": "number",
+ "boolean": "boolean",
+ "null": "null",
+ "object": "{}",
+ "array": "Array",
+ }
+ self.types = [mapping[t] for t in json_schema_object["type"]]
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ return " | ".join(self.types)
+
+
+class _ParameterTypeRef(_BaseType):
+ ref_name: str
+ is_self_ref: bool = False
+
+ def __init__(self, json_schema_object: dict[str, Any], registry: _SchemaRegistry):
+ super().__init__(json_schema_object)
+
+ ref = json_schema_object["$ref"]
+ resolved_schema = registry.resolve_ref(ref)
+
+ if resolved_schema.get("$self_ref", False):
+ self.ref_name = "parameters"
+ self.is_self_ref = True
+ else:
+ self.ref_name = ref.split("/")[-1]
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ return self.ref_name
+
+
+_ParameterType = (
+ _ParameterTypeScalar
+ | _ParameterTypeObject
+ | _ParameterTypeArray
+ | _ParameterTypeEnum
+ | _ParameterTypeAnyOf
+ | _ParameterTypeUnion
+ | _ParameterTypeRef
+)
+
+
+@dataclasses.dataclass
+class _Parameter:
+ """
+ A parameter in a function, or a field in an object.
+ It consists of the type as well as the name.
+ """
+
+ type: _ParameterType
+ name: str = "_"
+ optional: bool = True
+ default: Any | None = None
+
+ @classmethod
+ def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter":
+ if not attributes:
+ raise ValueError("attributes is empty")
+
+ return cls(
+ name=attributes.get("name", "_"),
+ type=_parse_parameter_type(attributes),
+ optional=attributes.get("optional", False),
+ default=attributes.get("default"),
+ )
+
+ def to_typescript_style(self, indent: str = "") -> str:
+ comments = self.type.format_docstring(indent)
+
+ if self.default is not None:
+ default_repr = (
+ json.dumps(self.default, ensure_ascii=False)
+ if not isinstance(self.default, (int, float, bool))
+ else repr(self.default)
+ )
+ comments += f"{indent}// Default: {default_repr}\n"
+
+ return (
+ comments
+ + f"{indent}{self.name}{'?' if self.optional else ''}: {self.type.to_typescript_style(indent=indent)}"
+ )
+
+
+def _parse_parameter_type(
+ json_schema_object: dict[str, Any] | bool, registry: _SchemaRegistry | None = None
+) -> _ParameterType:
+ if isinstance(json_schema_object, bool):
+ if json_schema_object:
+ return _ParameterTypeScalar(type="any")
+ else:
+ logger.warning(
+ f"Warning: Boolean value {json_schema_object} is not supported, use null instead."
+ )
+ return _ParameterTypeScalar(type="null")
+
+ if "$ref" in json_schema_object and registry:
+ return _ParameterTypeRef(json_schema_object, registry)
+
+ if "anyOf" in json_schema_object:
+ return _ParameterTypeAnyOf(json_schema_object, registry)
+ elif "enum" in json_schema_object:
+ return _ParameterTypeEnum(json_schema_object)
+ elif "type" in json_schema_object:
+ typ = json_schema_object["type"]
+ if isinstance(typ, list):
+ return _ParameterTypeUnion(json_schema_object)
+ elif typ == "object":
+ return _ParameterTypeObject(json_schema_object, registry)
+ elif typ == "array":
+ return _ParameterTypeArray(json_schema_object, registry)
+ else:
+ return _ParameterTypeScalar(typ, json_schema_object)
+ elif json_schema_object == {}:
+ return _ParameterTypeScalar(type="any")
+ else:
+ raise ValueError(f"Invalid JSON Schema object: {json_schema_object}")
+
+
+def _openai_function_to_typescript_style(
+ function: dict[str, Any],
+) -> str:
+ """Convert OpenAI function definition (dict) to TypeScript style string."""
+ registry = _SchemaRegistry()
+ parameters = function.get("parameters") or {}
+ parsed = _ParameterTypeObject(parameters, registry)
+
+ interfaces = []
+ root_interface_name = None
+ if registry.has_self_ref:
+ root_interface_name = "parameters"
+ params_str = _TS_FIELD_DELIMITER.join(
+ [p.to_typescript_style(indent=_TS_INDENT) for p in parsed.properties]
+ )
+ params_str = f"\n{params_str}\n" if params_str else ""
+ interface_def = f"interface {root_interface_name} {{{params_str}}}"
+ interfaces.append(interface_def)
+
+ definitions_copy = dict(registry.definitions)
+ for def_name, def_schema in definitions_copy.items():
+ obj_type = _parse_parameter_type(def_schema, registry)
+ params_str = obj_type.to_typescript_style()
+
+ description_part = ""
+ if obj_description := def_schema.get("description", ""):
+ description_part = _format_description(obj_description) + "\n"
+
+ interface_def = f"{description_part}interface {def_name} {params_str}"
+ interfaces.append(interface_def)
+
+ interface_str = "\n".join(interfaces)
+ function_name = function.get("name", "function")
+ if root_interface_name:
+ type_def = f"type {function_name} = (_: {root_interface_name}) => any;"
+ else:
+ params_str = parsed.to_typescript_style()
+ type_def = f"type {function_name} = (_: {params_str}) => any;"
+
+ description = function.get("description")
+ return "\n".join(
+ filter(
+ bool,
+ [
+ interface_str,
+ ((description and _format_description(description)) or ""),
+ type_def,
+ ],
+ )
+ )
+
+
+def encode_tools_to_typescript_style(
+ tools: list[dict[str, Any]],
+) -> str:
+ """
+ Convert tools (list of dict) to TypeScript style string.
+
+ Supports OpenAI format: {"type": "function", "function": {...}}
+
+ Args:
+ tools: List of tool definitions in dict format
+
+ Returns:
+ TypeScript style string representation of the tools
+ """
+ if not tools:
+ return ""
+
+ functions = []
+
+ for tool in tools:
+ tool_type = tool.get("type")
+ if tool_type == "function":
+ func_def = tool.get("function", {})
+ if func_def:
+ functions.append(_openai_function_to_typescript_style(func_def))
+ else:
+ # Skip unsupported tool types (like "_plugin")
+ continue
+
+ if not functions:
+ return ""
+
+ functions_str = "\n".join(functions)
+ result = "# Tools\n\n"
+
+ if functions_str:
+ result += "## functions\nnamespace functions {\n"
+ result += functions_str + "\n"
+ result += "}\n"
+
+ return result
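+
+
+# Example (illustrative tool definition): a single weather tool,
+#
+#   encode_tools_to_typescript_style([
+#       {
+#           "type": "function",
+#           "function": {
+#               "name": "get_weather",
+#               "description": "Get current weather",
+#               "parameters": {
+#                   "type": "object",
+#                   "properties": {"city": {"type": "string"}},
+#                   "required": ["city"],
+#               },
+#           },
+#       }
+#   ])
+#
+# renders (modulo indentation width) as:
+#
+#   # Tools
+#
+#   ## functions
+#   namespace functions {
+#   // Get current weather
+#   type get_weather = (_: {
+#     city: string
+#   }) => any;
+#   }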
diff --git a/src/art/tinker/cookbook_v/renderers/qwen3.py b/src/art/tinker/cookbook_v/renderers/qwen3.py
index 343a6ece..6e0fe3e4 100644
--- a/src/art/tinker/cookbook_v/renderers/qwen3.py
+++ b/src/art/tinker/cookbook_v/renderers/qwen3.py
@@ -24,6 +24,7 @@
RenderedMessage,
Renderer,
TextPart,
+ ToolCall,
ToolSpec,
UnparsedToolCall,
_tool_call_payload,
@@ -152,7 +153,6 @@ def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessag
rendered_parts.append(f"<think>{p['thinking']}</think>")
elif p["type"] == "text":
rendered_parts.append(p["text"])
- # ToolCallPart handled via message's tool_calls field
output_content = "".join(rendered_parts)
else:
# String content - pass through as-is.
@@ -211,23 +211,16 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
content = assistant_message["content"]
# Parse all blocks in one pass, preserving order
- parts = parse_content_blocks(content)
+ result = parse_content_blocks(content)
- if parts is not None:
+ if result is not None:
+ parts, tool_results = result
assistant_message["content"] = parts
- # Also populate tool_calls and unparsed_tool_calls fields for backward compatibility
- # TODO: Consider moving away from TypedDicts for part types - current approach
- # relies on runtime type checking (p["type"] == "tool_call") without static guarantees.
- tool_calls = [p["tool_call"] for p in parts if p["type"] == "tool_call"]
+ tool_calls = [t for t in tool_results if isinstance(t, ToolCall)]
+ unparsed = [t for t in tool_results if isinstance(t, UnparsedToolCall)]
if tool_calls:
assistant_message["tool_calls"] = tool_calls
-
- unparsed = [
- UnparsedToolCall(raw_text=p["raw_text"], error=p["error"])
- for p in parts
- if p["type"] == "unparsed_tool_call"
- ]
if unparsed:
assistant_message["unparsed_tool_calls"] = unparsed
else:
@@ -273,7 +266,9 @@ def to_openai_message(self, message: Message) -> dict:
"id": tc.id,
"function": {
"name": tc.function.name,
- "arguments": tc.function.arguments,
+ "arguments": self._to_openai_tool_arguments(
+ tc.function.arguments
+ ),
},
}
for tc in message["tool_calls"]
@@ -288,6 +283,14 @@ def to_openai_message(self, message: Message) -> dict:
return result
+ def _to_openai_tool_arguments(self, arguments: str) -> str | dict:
+ """Convert tool arguments for OpenAI-compatible message payloads.
+
+ Qwen3 templates accept JSON-string arguments directly; subclasses can
+ override to return dicts for templates that iterate over arguments.
+ """
+ return arguments
+
def create_conversation_prefix_with_tools(
self, tools: list[ToolSpec], system_prompt: str = ""
) -> list[Message]:
@@ -306,7 +309,8 @@ def create_conversation_prefix_with_tools(
# Use separators=(", ", ": ") to match HF's tojson filter output
tool_lines = "\n".join(
json.dumps(
- {"type": "function", "function": tool}, separators=(", ", ": ")
+ {"type": "function", "function": tool},
+ separators=(", ", ": "),
)
for tool in tools
)
@@ -415,12 +419,12 @@ class Qwen3VLRenderer(Qwen3Renderer):
The default strip_thinking_from_history=True matches the non-VL Qwen3Renderer behavior.
"""
- image_processor: ImageProcessor
+ image_processor: ImageProcessor | None
def __init__(
self,
tokenizer: Tokenizer,
- image_processor: ImageProcessor,
+ image_processor: ImageProcessor | None = None,
strip_thinking_from_history: bool = True,
merge_text_chunks: bool = True,
):
@@ -429,14 +433,21 @@ def __init__(
self.strip_thinking_from_history = strip_thinking_from_history
self.merge_text_chunks = merge_text_chunks
+ def _format_thinking_text(self, thinking: str) -> str:
+ """Format a ThinkingPart payload for rendering."""
+ return f"<think>{thinking}</think>"
+
+ def _assistant_header_suffix(self, message: Message, ctx: RenderContext) -> str:
+ """Additional assistant header text injected before content."""
+ return ""
+
def _preprocess_message_parts(
self, message: Message, *, strip_thinking: bool = False
) -> list[ImagePart | TextPart]:
"""Convert message content to list form for VL rendering.
Converts ThinkingPart to <think>...</think> text (or strips if strip_thinking=True).
- Wraps images with vision tokens. ToolCallPart is not supported in VL content list
- (use message's tool_calls field instead).
+ Wraps images with vision tokens. Tool calls are in message's tool_calls field.
"""
content = message["content"]
if isinstance(content, str):
@@ -456,11 +467,11 @@ def _preprocess_message_parts(
# Render thinking as <think>...</think> text
base_parts.append(
TextPart(
- type="text", text=f"<think>{p['thinking']}</think>"
+ type="text",
+ text=self._format_thinking_text(p["thinking"]),
)
)
# else: strip thinking by not appending
- # ToolCallPart and UnparsedToolCallPart are handled via message's tool_calls field
# Wrap images with vision tokens
chunks: list[ImagePart | TextPart] = []
@@ -485,11 +496,30 @@ def _wrap_qwen_tool_response_chunks(
+ [TextPart(type="text", text="\n")]
)
+ def _format_tool_calls_chunks(self, message: Message) -> list[ImagePart | TextPart]:
+ """Format tool_calls as output chunks. Override in subclasses for different formats."""
+ # Add leading newline to match HF template behavior
+ assert "tool_calls" in message, "tool_calls are required to format tool calls"
+ return [
+ TextPart(
+ type="text",
+ text="\n"
+ + "\n".join(
+ [
+ f"\n{json.dumps(_tool_call_payload(tool_call))}\n"
+ for tool_call in message["tool_calls"]
+ ]
+ ),
+ )
+ ]
+
def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessage:
maybe_newline = "\n" if ctx.idx > 0 else ""
role = self._get_qwen_role_for_message(message)
header_str = f"{maybe_newline}<|im_start|>{role}\n"
+ if message["role"] == "assistant":
+ header_str += self._assistant_header_suffix(message, ctx)
# Strip thinking from history for non-last assistant messages (matching non-VL behavior)
strip_thinking = (
@@ -506,35 +536,34 @@ def render_message(self, message: Message, ctx: RenderContext) -> RenderedMessag
output_chunks = self._wrap_qwen_tool_response_chunks(output_chunks)
if "tool_calls" in message:
- # Add leading newline to match HF template behavior
- output_chunks += [
- TextPart(
- type="text",
- text="\n"
- + "\n".join(
- [
- f"\n{json.dumps(_tool_call_payload(tool_call))}\n"
- for tool_call in message["tool_calls"]
- ]
- ),
- )
- ]
+ output_chunks += self._format_tool_calls_chunks(message)
output_chunks += [TextPart(type="text", text="<|im_end|>")]
if self.merge_text_chunks:
output_chunks = _merge_consecutive_text_parts(output_chunks)
- output_chunks_encoded: list[tinker.ModelInputChunk] = [
- image_to_chunk(
- image_or_str=x["image"],
- image_processor=cast(ImageProcessorProtocol, self.image_processor),
- )
- if x["type"] == "image"
- else tinker.EncodedTextChunk(
- tokens=self.tokenizer.encode(x["text"], add_special_tokens=False)
- )
- for x in output_chunks
- ]
+ output_chunks_encoded: list[tinker.ModelInputChunk] = []
+ for x in output_chunks:
+ if x["type"] == "image":
+ assert self.image_processor is not None, (
+ "image_processor is required to render image content"
+ )
+ output_chunks_encoded.append(
+ image_to_chunk(
+ image_or_str=x["image"],
+ image_processor=cast(
+ ImageProcessorProtocol, self.image_processor
+ ),
+ )
+ )
+ else:
+ output_chunks_encoded.append(
+ tinker.EncodedTextChunk(
+ tokens=self.tokenizer.encode(
+ x["text"], add_special_tokens=False
+ )
+ )
+ )
header = tinker.types.EncodedTextChunk(
tokens=self.tokenizer.encode(header_str, add_special_tokens=False)
diff --git a/src/art/tinker/cookbook_v/renderers/qwen3_5.py b/src/art/tinker/cookbook_v/renderers/qwen3_5.py
new file mode 100644
index 00000000..0961b708
--- /dev/null
+++ b/src/art/tinker/cookbook_v/renderers/qwen3_5.py
@@ -0,0 +1,306 @@
+"""
+Qwen3.5 family renderer.
+
+Qwen3.5 models are VL models with the same basic
+chat format as Qwen3-VL (im_start/im_end, thinking, vision tokens) but with a
+different tool calling format.
+
+Tool calling differences from Qwen3:
+- Qwen3: JSON format {"name": ..., "arguments": ...}
+- Qwen3.5: XML format <tool_call><function=name><parameter=key>value</parameter></function></tool_call>
+
+Unlike Qwen3, the Qwen3.5 HF template:
+- Always adds <think>...</think> blocks to assistant messages after the last user
+ message (empty if no reasoning content).
+- Always adds <think>\\n to the generation prompt.
+
+Reference: https://huggingface.co/Qwen/Qwen3.5-4B/blob/main/tokenizer_config.json
+"""
+
+import json
+import re
+
+from .base import (
+ ImagePart,
+ Message,
+ RenderContext,
+ Role,
+ TextPart,
+ ToolCall,
+ ToolSpec,
+ UnparsedToolCall,
+)
+from .qwen3 import Qwen3VLRenderer
+
+_FUNCTION_BLOCK_RE = re.compile(
+ r"^\s*\s*[^>\n]+)>\s*(?P.*?)\s*\s*\s*$",
+ re.DOTALL,
+)
+_PARAM_BLOCK_RE = re.compile(
+ r"[^>\n]+)>\s*(?P.*?)\s*",
+ re.DOTALL,
+)
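+
+# For reference, these regexes accept tool-call text shaped like the following
+# (function/parameter names here are illustrative):
+#
+#   <tool_call>
+#   <function=lookup_weather>
+#   <parameter=city>
+#   San Francisco
+#   </parameter>
+#   </function>
+#   </tool_call>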
+
+
+class Qwen3_5Renderer(Qwen3VLRenderer):
+ """
+ Renderer for Qwen3.5 models.
+
+ Subclasses Qwen3VLRenderer since Qwen3.5 models are VL models sharing the same
+ basic chat format. Overrides tool calling to use Qwen3.5's XML parameter format.
+
+ The Qwen3.5 HF template adds empty <think> blocks to assistant messages after
+ the last user message. This is handled via ctx.last_user_index, which is
+ populated by the base build_generation_prompt/build_supervised_example.
+ """
+
+ def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:
+ """Override to produce the full generation suffix directly.
+
+ Builds the header tokens manually and appends <think>\\n. This matches
+ the Qwen3.5 template's add_generation_prompt behavior for thinking mode.
+ """
+ maybe_newline = "\n" if ctx.idx > 0 else ""
+ header_str = f"{maybe_newline}<|im_start|>{role}\n\n"
+ return self.tokenizer.encode(header_str, add_special_tokens=False)
+
+ def _assistant_header_suffix(self, message: Message, ctx: RenderContext) -> str:
+ """Insert empty think block for assistant messages after the last user query."""
+ if ctx.idx <= ctx.last_user_index:
+ return ""
+
+ content = message.get("content", "")
+ has_think = False
+ if isinstance(content, list):
+ has_think = any(p["type"] == "thinking" for p in content)
+ elif isinstance(content, str):
+ has_think = "" in content
+
+ return "" if has_think else "\n\n\n\n"
+
+ def _format_thinking_text(self, thinking: str) -> str:
+ """Qwen3.5 uses newline-padded think blocks."""
+ return f"\n{thinking}\n\n\n"
+
+ def _to_openai_tool_arguments(self, arguments: str) -> str | dict:
+ """Qwen3.5 chat template expects arguments as a mapping for |items."""
+ return json.loads(arguments)
+
+ def _parse_qwen3_5_tool_call_xml(
+ self, raw_text: str
+ ) -> ToolCall | UnparsedToolCall:
+ """Parse Qwen3.5 XML-style tool calls from a raw block."""
+ match = _FUNCTION_BLOCK_RE.match(raw_text)
+ if not match:
+ return UnparsedToolCall(
+ raw_text=raw_text, error="Malformed Qwen3.5 tool call XML"
+ )
+
+ function_name = match.group("name").strip()
+ body = match.group("body")
+ if not function_name:
+ return UnparsedToolCall(raw_text=raw_text, error="Missing function name")
+
+ arguments: dict[str, object] = {}
+ pos = 0
+ for param in _PARAM_BLOCK_RE.finditer(body):
+ if body[pos : param.start()].strip():
+ return UnparsedToolCall(
+ raw_text=raw_text,
+ error="Unexpected non-parameter content inside block",
+ )
+
+ param_name = param.group("name").strip()
+ param_value_text = param.group("value").strip("\n")
+ if not param_name:
+ return UnparsedToolCall(raw_text=raw_text, error="Empty parameter name")
+
+ try:
+ param_value: object = json.loads(param_value_text)
+ except json.JSONDecodeError:
+ param_value = param_value_text
+
+ arguments[param_name] = param_value
+ pos = param.end()
+
+ if body[pos:].strip():
+ return UnparsedToolCall(
+ raw_text=raw_text,
+ error="Unexpected trailing content inside block",
+ )
+
+ return ToolCall(
+ function=ToolCall.FunctionBody(
+ name=function_name,
+ arguments=json.dumps(arguments),
+ )
+ )
+
+ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
+ """Parse response, prepending \\n since the generation prompt prefills it.
+
+ When sampling with build_generation_prompt, \\n is part of the generation
+ suffix and not included in the sampled tokens. The response will be
+ "reasoning\\n\\n\\nanswer" so we prepend \\n if necessary.
+
+ Also strips leading/trailing whitespace from thinking content to match the
+ HF template behavior (which applies |trim to reasoning_content).
+ """
+ think_prefix_tokens = self.tokenizer.encode(
+ "<think>\n", add_special_tokens=False
+ )
+ think_suffix_token = self.tokenizer.encode("</think>", add_special_tokens=False)
+ assert len(think_suffix_token) == 1
+
+ starts_with_think = (
+ len(response) >= len(think_prefix_tokens)
+ and response[: len(think_prefix_tokens)] == think_prefix_tokens
+ )
+ if not starts_with_think and think_suffix_token[0] in response:
+ response = think_prefix_tokens + response
+
+ message, success = super().parse_response(response)
+ if not success:
+ return message, success
+
+ # Strip whitespace from thinking content (matches HF template |trim behavior)
+ content = message.get("content")
+ if isinstance(content, list):
+ first_text_after_thinking: TextPart | None = None
+ seen_thinking = False
+ for p in content:
+ if p["type"] == "thinking":
+ p["thinking"] = p["thinking"].strip()
+ seen_thinking = True
+ elif seen_thinking and p["type"] == "text":
+ first_text_after_thinking = p
+ break
+
+ # Template inserts exactly two separator newlines between </think> and text.
+ if first_text_after_thinking is not None and first_text_after_thinking[
+ "text"
+ ].startswith("\n\n"):
+ first_text_after_thinking["text"] = first_text_after_thinking["text"][
+ 2:
+ ]
+
+ # Qwen3 parent parser assumes JSON inside <tool_call>; convert XML blocks here.
+ converted_xml_calls: list[ToolCall] = []
+ remaining_unparsed: list[UnparsedToolCall] = []
+ for unparsed in message.get("unparsed_tool_calls", []):
+ if " str:
+ """Format a single tool call in Qwen3.5's XML parameter format."""
+ args = (
+ json.loads(tool_call.function.arguments)
+ if tool_call.function.arguments
+ else {}
+ )
+ lines = [f"\n"]
+ for param_name, param_value in args.items():
+ if isinstance(param_value, (dict, list)):
+ value_str = json.dumps(param_value)
+ else:
+ value_str = str(param_value)
+ lines.append(f"\n{value_str}\n")
+ lines.append("\n")
+ return "\n".join(lines)
+
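+ # For example, _format_tool_call_xml renders arguments '{"city": "SF", "days": 3}'
+ # on an illustrative "lookup_weather" call as:
+ #   <tool_call>
+ #   <function=lookup_weather>
+ #   <parameter=city>
+ #   SF
+ #   </parameter>
+ #   <parameter=days>
+ #   3
+ #   </parameter>
+ #   </function>
+ #   </tool_call>
+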
+ def _format_tool_calls_chunks(self, message: Message) -> list[ImagePart | TextPart]:
+ """Format tool_calls using Qwen3.5's XML parameter format."""
+ assert "tool_calls" in message, "tool_calls are required to format tool calls"
+ return [
+ TextPart(
+ type="text",
+ text="\n\n"
+ + "\n".join(
+ self._format_tool_call_xml(tc) for tc in message["tool_calls"]
+ ),
+ )
+ ]
+
+ def create_conversation_prefix_with_tools(
+ self, tools: list[ToolSpec], system_prompt: str = ""
+ ) -> list[Message]:
+ """Create system message with Qwen3.5 tool specifications.
+
+ Qwen3.5 uses a different tool declaration format from Qwen3, with XML-based
+ function/parameter calling syntax.
+
+ Reference: https://huggingface.co/Qwen/Qwen3.5-4B/blob/main/tokenizer_config.json
+ """
+ tools_text = ""
+ if tools:
+ tool_lines = "\n".join(json.dumps(tool) for tool in tools)
+ tools_text = (
+ "# Tools\n\n"
+ "You have access to the following functions:\n\n"
+ "\n"
+ f"{tool_lines}\n"
+ "\n\n"
+ "If you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
+ "\n"
+ "\n"
+ "\n"
+ "value_1\n"
+ "\n"
+ "\n"
+ "This is the value for the second parameter\n"
+ "that can span\n"
+ "multiple lines\n"
+ "\n"
+ "\n"
+ "\n\n"
+ "\n"
+ "Reminder:\n"
+ "- Function calls MUST follow the specified format: "
+ "an inner block must be nested within "
+ " XML tags\n"
+ "- Required parameters MUST be specified\n"
+ "- You may provide optional reasoning for your function call in natural language "
+ "BEFORE the function call, but NOT after\n"
+ "- If there is no function call available, answer the question like normal with "
+ "your current knowledge and do not tell the user about function calls\n"
+ ""
+ )
+
+ if tools_text:
+ content = (
+ tools_text + "\n\n" + system_prompt if system_prompt else tools_text
+ )
+ else:
+ content = system_prompt
+
+ return [Message(role="system", content=content)]
+
+
+class Qwen3_5DisableThinkingRenderer(Qwen3_5Renderer):
+ """
+ Renderer for Qwen3.5 models with thinking disabled.
+
+ Matches the Qwen3.5 HF template with enable_thinking=False. The only difference
+ from Qwen3_5Renderer is the generation suffix: <think>\\n\\n</think>\\n\\n instead
+ of <think>\\n, signaling to the model to respond directly without reasoning.
+ """
+
+ def _get_generation_suffix(self, role: Role, ctx: RenderContext) -> list[int]:
+ maybe_newline = "\n" if ctx.idx > 0 else ""
+ header_str = f"{maybe_newline}<|im_start|>{role}\n<think>\n\n</think>\n\n"
+ return self.tokenizer.encode(header_str, add_special_tokens=False)
diff --git a/src/art/tinker/cookbook_v/renderers/role_colon.py b/src/art/tinker/cookbook_v/renderers/role_colon.py
index 8f384f8d..595a7071 100644
--- a/src/art/tinker/cookbook_v/renderers/role_colon.py
+++ b/src/art/tinker/cookbook_v/renderers/role_colon.py
@@ -57,6 +57,15 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
logger = logging.getLogger(__name__)
+ # Strip EOS token from the end if present (base models may terminate with EOS
+ # instead of the expected stop sequence). We still return False for parse success
+ # since the model didn't produce the expected stop sequence.
+ terminated_with_eos = False
+ eos_token_id = self.tokenizer.eos_token_id
+ if eos_token_id is not None and response and response[-1] == eos_token_id:
+ response = response[:-1]
+ terminated_with_eos = True
+
str_response = self.tokenizer.decode(response)
splitted = str_response.split("\n\nUser:")
if len(splitted) == 1:
@@ -64,12 +73,17 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
return Message(role="assistant", content=str_response.strip()), False
elif len(splitted) == 2:
before, _after = splitted
- return Message(role="assistant", content=before.strip()), True
+ return Message(
+ role="assistant", content=before.strip()
+ ), not terminated_with_eos
else:
- raise ValueError(
- f"When parsing response, expected to split into 1 or 2 pieces using stop tokens, but got {len(splitted)}. "
- "You probably are using the wrong stop tokens when sampling"
+ logger.warning(
+ "RoleColonRenderer.parse_response saw multiple stop delimiters "
+ "(count=%d). Returning parse_success=False. Full response:\n%s",
+ len(splitted) - 1,
+ str_response,
)
+ return Message(role="assistant", content=splitted[0].strip()), False
@property
def _bos_tokens(self) -> list[int]:
diff --git a/src/art/tinker/cookbook_v/tokenizer_utils.py b/src/art/tinker/cookbook_v/tokenizer_utils.py
index 53d91fef..f55d801e 100644
--- a/src/art/tinker/cookbook_v/tokenizer_utils.py
+++ b/src/art/tinker/cookbook_v/tokenizer_utils.py
@@ -7,6 +7,7 @@
from __future__ import annotations
+from collections.abc import Callable
from functools import cache
from typing import TYPE_CHECKING, Any, TypeAlias
@@ -19,9 +20,68 @@
# make it importable from other files as a type in runtime
Tokenizer: TypeAlias = Any
+# Global registry for custom tokenizer factories
+_CUSTOM_TOKENIZER_REGISTRY: dict[str, Callable[[], Tokenizer]] = {}
+
+
+def register_tokenizer(
+ name: str,
+ factory: Callable[[], Tokenizer],
+) -> None:
+ """Register a custom tokenizer factory.
+
+ Args:
+ name: The tokenizer name
+ factory: A callable that takes no arguments and returns a Tokenizer.
+
+ Example:
+ def my_tokenizer_factory():
+ return MyCustomTokenizer()
+
+ register_tokenizer("Foo/foo_tokenizer", my_tokenizer_factory)
+ """
+ _CUSTOM_TOKENIZER_REGISTRY[name] = factory
+
+
+def get_registered_tokenizer_names() -> list[str]:
+ """Return a list of all registered custom tokenizer names."""
+ return list(_CUSTOM_TOKENIZER_REGISTRY.keys())
+
+
+def is_tokenizer_registered(name: str) -> bool:
+ """Check if a tokenizer name is registered."""
+ return name in _CUSTOM_TOKENIZER_REGISTRY
+
+
+def unregister_tokenizer(name: str) -> bool:
+ """Unregister a custom tokenizer factory.
+
+ Args:
+ name: The tokenizer name to unregister.
+
+ Returns:
+ True if the tokenizer was unregistered, False if it wasn't registered.
+ """
+ if name in _CUSTOM_TOKENIZER_REGISTRY:
+ del _CUSTOM_TOKENIZER_REGISTRY[name]
+ return True
+ return False
+
-@cache
def get_tokenizer(model_name: str) -> Tokenizer:
+ """Get a tokenizer by name.
+
+ Checks custom registry first, then falls back to HuggingFace AutoTokenizer.
+ """
+ # Check custom registry first (not cached, factory handles caching if needed)
+ if (tokenizer := _CUSTOM_TOKENIZER_REGISTRY.get(model_name)) is not None:
+ return tokenizer()
+
+ return _get_hf_tokenizer(model_name)
+
+
+@cache
+def _get_hf_tokenizer(model_name: str) -> Tokenizer:
from transformers.models.auto.tokenization_auto import AutoTokenizer
model_name = model_name.split(":")[0]
@@ -33,6 +93,9 @@ def get_tokenizer(model_name: str) -> Tokenizer:
kwargs: dict[str, Any] = {}
if model_name == "moonshotai/Kimi-K2-Thinking":
kwargs["trust_remote_code"] = True
- kwargs["revision"] = "612681931a8c906ddb349f8ad0f582cb552189cd"
+ kwargs["revision"] = "a51ccc050d73dab088bf7b0e2dd9b30ae85a4e55"
+ elif model_name == "moonshotai/Kimi-K2.5":
+ kwargs["trust_remote_code"] = True
+ kwargs["revision"] = "2426b45b6af0da48d0dcce71bbce6225e5c73adc"
return AutoTokenizer.from_pretrained(model_name, use_fast=True, **kwargs)
diff --git a/src/art/tinker/renderers.py b/src/art/tinker/renderers.py
index 527c90ad..f49bd1e0 100644
--- a/src/art/tinker/renderers.py
+++ b/src/art/tinker/renderers.py
@@ -1,6 +1,10 @@
def get_renderer_name(base_model: str) -> str:
if base_model.startswith("meta-llama/"):
return "llama3"
+ elif base_model.startswith("Qwen/Qwen3.5-"):
+ print("Defaulting to Qwen3.5 renderer with thinking for", base_model)
+ print(renderer_name_message)
+ return "qwen3_5"
elif base_model.startswith("Qwen/Qwen3-"):
if "Instruct" in base_model:
return "qwen3_instruct"
@@ -8,6 +12,14 @@ def get_renderer_name(base_model: str) -> str:
print("Defaulting to Qwen3 renderer without thinking for", base_model)
print(renderer_name_message)
return "qwen3_disable_thinking"
+ elif base_model.startswith("moonshotai/Kimi-K2.5"):
+ print("Defaulting to Kimi K2.5 renderer with thinking for", base_model)
+ print(renderer_name_message)
+ return "kimi_k25"
+ elif base_model.startswith("moonshotai/Kimi-K2"):
+ print("Defaulting to Kimi K2 renderer with thinking for", base_model)
+ print(renderer_name_message)
+ return "kimi_k2"
elif base_model.startswith("deepseek-ai/DeepSeek-V3"):
print("Defaulting to DeepSeekV3 renderer without thinking for", base_model)
print(renderer_name_message)
@@ -34,12 +46,21 @@ def get_renderer_name(base_model: str) -> str:
Valid renderer names are:
+- role_colon
- llama3
- qwen3
+- qwen3_vl
+- qwen3_vl_instruct
- qwen3_disable_thinking
- qwen3_instruct
+- qwen3_5
+- qwen3_5_disable_thinking
- deepseekv3
- deepseekv3_disable_thinking
+- deepseekv3_thinking
+- kimi_k2
+- kimi_k25
+- kimi_k25_disable_thinking
- gpt_oss_no_sysprompt
- gpt_oss_low_reasoning
- gpt_oss_medium_reasoning
diff --git a/src/art/tinker/server.py b/src/art/tinker/server.py
index 32f41ca1..8be7e780 100644
--- a/src/art/tinker/server.py
+++ b/src/art/tinker/server.py
@@ -1,6 +1,7 @@
import asyncio
from dataclasses import dataclass, field
from itertools import cycle
+import json
import os
import socket
import time
@@ -102,7 +103,13 @@ async def chat_completion_and_token_discrepancies(
id=tool_call.get("id") or "",
function=Function(
name=tool_call["function"]["name"],
- arguments=tool_call["function"]["arguments"],
+ arguments=(
+ tool_call["function"]["arguments"]
+ if isinstance(
+ tool_call["function"]["arguments"], str
+ )
+ else json.dumps(tool_call["function"]["arguments"])
+ ),
),
)
for tool_call in openai_message.get("tool_calls", [])
@@ -160,6 +167,7 @@ def _get_renderer(self, base_model: str) -> renderers.Renderer:
self._renderers[base_model] = renderers.get_renderer(
name=get_renderer_name(base_model),
tokenizer=get_tokenizer(base_model),
+ model_name=base_model,
)
return self._renderers[base_model]
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
index ef88fef9..c1687bf7 100644
--- a/src/art/tinker_native/backend.py
+++ b/src/art/tinker_native/backend.py
@@ -2,6 +2,7 @@
import asyncio
from dataclasses import dataclass
+import json
import os
import re
import time
@@ -436,7 +437,13 @@ async def chat_completions(body: CompletionCreateParams) -> ChatCompletion:
id=tool_call.get("id") or f"call_{idx}",
function=Function(
name=tool_call["function"]["name"],
- arguments=tool_call["function"]["arguments"],
+ arguments=(
+ tool_call["function"]["arguments"]
+ if isinstance(
+ tool_call["function"]["arguments"], str
+ )
+ else json.dumps(tool_call["function"]["arguments"])
+ ),
),
)
for idx, tool_call in enumerate(parsed_message["tool_calls"])
@@ -526,6 +533,7 @@ async def _build_model_state(self, model: TrainableModel) -> ModelState:
renderer = renderers.get_renderer(
name=config.renderer_name,
tokenizer=tokenizer,
+ model_name=model.base_model,
)
saved_state = model.read_state() or {}
diff --git a/src/art/tinker_native/data.py b/src/art/tinker_native/data.py
index c4386d5f..6b29bcea 100644
--- a/src/art/tinker_native/data.py
+++ b/src/art/tinker_native/data.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import json
from typing import Any, Iterable, cast
from openai.types.chat.chat_completion import Choice
@@ -79,12 +80,15 @@ def convert_openai_messages_to_renderer_format(
tool_calls = []
for tool_call in msg["tool_calls"]:
func = tool_call.get("function", {})
+ arguments = func.get("arguments", "{}")
+ if not isinstance(arguments, str):
+ arguments = json.dumps(arguments)
tool_calls.append(
renderers.ToolCall(
id=tool_call.get("id", ""),
function=renderers.ToolCall.FunctionBody(
name=func.get("name", ""),
- arguments=func.get("arguments", "{}"),
+ arguments=arguments,
),
)
)
diff --git a/tests/unit/test_tinker_renderers.py b/tests/unit/test_tinker_renderers.py
new file mode 100644
index 00000000..9d388449
--- /dev/null
+++ b/tests/unit/test_tinker_renderers.py
@@ -0,0 +1,196 @@
+import json
+from typing import cast
+
+from art.tinker.cookbook_v import renderers
+from art.tinker.cookbook_v.tokenizer_utils import Tokenizer
+from art.tinker.renderers import get_renderer_name
+from art.tinker_native.data import convert_openai_messages_to_renderer_format
+
+
+class FakeTokenizer:
+ name_or_path = "fake/qwen3_5"
+
+ _SPECIAL_TOKENS = ("<|im_end|>", "</think>")
+
+ def __init__(self) -> None:
+ self._text_to_id: dict[str, int] = {}
+ self._id_to_text: dict[int, str] = {}
+ self._next_id = 100
+ for idx, token in enumerate(self._SPECIAL_TOKENS, start=1):
+ self._text_to_id[token] = idx
+ self._id_to_text[idx] = token
+
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+ del add_special_tokens
+ tokens: list[int] = []
+ idx = 0
+ while idx < len(text):
+ matched = False
+ for special in self._SPECIAL_TOKENS:
+ if text.startswith(special, idx):
+ tokens.append(self._text_to_id[special])
+ idx += len(special)
+ matched = True
+ break
+ if matched:
+ continue
+
+ char = text[idx]
+ if char not in self._text_to_id:
+ self._text_to_id[char] = self._next_id
+ self._id_to_text[self._next_id] = char
+ self._next_id += 1
+ tokens.append(self._text_to_id[char])
+ idx += 1
+ return tokens
+
+ def decode(self, tokens: int | list[int]) -> str:
+ if isinstance(tokens, int):
+ return self._id_to_text[tokens]
+ return "".join(self._id_to_text[token] for token in tokens)
+
+
+def _decode_model_input(tokenizer: FakeTokenizer, model_input: object) -> str:
+ tokens: list[int] = []
+ for chunk in model_input.chunks: # type: ignore[attr-defined]
+ assert hasattr(chunk, "tokens"), f"Unexpected non-text chunk: {chunk!r}"
+ tokens.extend(list(chunk.tokens))
+ return tokenizer.decode(tokens)
+
+
+def _get_test_renderer(name: str, tokenizer: FakeTokenizer) -> renderers.Renderer:
+ return renderers.get_renderer(name, cast(Tokenizer, tokenizer))
+
+
+def test_get_renderer_name_autodetects_qwen3_5() -> None:
+ assert get_renderer_name("Qwen/Qwen3.5-35B-A3B") == "qwen3_5"
+
+
+def test_qwen3_5_generation_prompt_matches_hf_suffixes() -> None:
+ tokenizer = FakeTokenizer()
+
+ renderer = _get_test_renderer("qwen3_5", tokenizer)
+ prompt = renderer.build_generation_prompt(
+ [
+ {"role": "user", "content": "Question"},
+ {"role": "assistant", "content": "Interim answer"},
+ ]
+ )
+ rendered = _decode_model_input(tokenizer, prompt)
+ assert (
+ "<|im_start|>assistant\n\n\n\n\nInterim answer<|im_end|>"
+ in rendered
+ )
+ assert rendered.endswith("<|im_start|>assistant\n<think>\n")
+
+ disable_renderer = _get_test_renderer("qwen3_5_disable_thinking", tokenizer)
+ disable_prompt = disable_renderer.build_generation_prompt([])
+ disable_rendered = _decode_model_input(tokenizer, disable_prompt)
+ assert disable_rendered == "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+
+
+def test_qwen3_5_parse_response_handles_xml_tool_calls() -> None:
+ tokenizer = FakeTokenizer()
+ renderer = _get_test_renderer("qwen3_5", tokenizer)
+
+ response = tokenizer.encode(
+ " reasoning \n\nAnswer first.\n\n"
+ "\n"
+ "\n"
+ "\n"
+ "San Francisco\n"
+ "\n"
+ "\n"
+ "3\n"
+ "\n"
+ "\n"
+ ""
+ "<|im_end|>"
+ )
+
+ message, success = renderer.parse_response(response)
+
+ assert success is True
+ assert message["content"] == [
+ {"type": "thinking", "thinking": "reasoning"},
+ {"type": "text", "text": "Answer first.\n\n"},
+ ]
+ assert "unparsed_tool_calls" not in message
+ assert len(message["tool_calls"]) == 1
+ assert message["tool_calls"][0].function.name == "lookup_weather"
+ assert json.loads(message["tool_calls"][0].function.arguments) == {
+ "city": "San Francisco",
+ "days": 3,
+ }
+
+
+def test_qwen3_5_to_openai_message_uses_mapping_tool_arguments() -> None:
+ tokenizer = FakeTokenizer()
+ renderer = _get_test_renderer("qwen3_5", tokenizer)
+
+ message: renderers.Message = {
+ "role": "assistant",
+ "content": [
+ renderers.ThinkingPart(type="thinking", thinking="reason"),
+ renderers.TextPart(type="text", text="Answer"),
+ ],
+ "tool_calls": [
+ renderers.ToolCall(
+ function=renderers.ToolCall.FunctionBody(
+ name="lookup_weather",
+ arguments=json.dumps({"city": "San Francisco", "days": 3}),
+ )
+ )
+ ],
+ }
+
+ openai_message = renderer.to_openai_message(message)
+
+ assert openai_message["content"] == "Answer"
+ assert openai_message["reasoning_content"] == "reason"
+ assert openai_message["tool_calls"][0]["function"]["arguments"] == {
+ "city": "San Francisco",
+ "days": 3,
+ }
+
+
+def test_convert_openai_messages_to_renderer_format_stringifies_dict_arguments() -> (
+ None
+):
+ tokenizer = FakeTokenizer()
+ renderer = _get_test_renderer("qwen3_5", tokenizer)
+
+ converted = convert_openai_messages_to_renderer_format(
+ [
+ {
+ "role": "assistant",
+ "content": "Calling a tool",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "function": {
+ "name": "lookup_weather",
+ "arguments": {"city": "San Francisco", "days": 3},
+ },
+ }
+ ],
+ }
+ ],
+ tools=None,
+ renderer=renderer,
+ )
+
+ tool_call = converted[0]["tool_calls"][0]
+ assert tool_call.function.name == "lookup_weather"
+ assert json.loads(tool_call.function.arguments) == {
+ "city": "San Francisco",
+ "days": 3,
+ }
+
+
+def test_get_renderer_supports_kimi_k25_factory() -> None:
+ tokenizer = FakeTokenizer()
+
+ renderer = _get_test_renderer("kimi_k25", tokenizer)
+
+ assert renderer.__class__.__name__ == "KimiK25Renderer"
|