diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13a..d055aaf6c9 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -2,7 +2,7 @@ import os import torch -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate from .base_att import BaseAttBackend @@ -13,18 +13,24 @@ from .fa3.fp import Fa3AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend +from .paged_fa3.fp import PagedFa3AttBackend +from .paged_fa3.mla import PagedMlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend from .flashinfer.fp import FlashInferAttBackend from .flashinfer.mla import MlaFlashInferAttBackend +from .paged_flashinfer.fp import PagedFlashInferAttBackend +from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend logger = init_logger(__name__) +_PAGE_ENABLED = get_page_size() > 1 + # Backend class mappings by data type data_type_to_backend = { "None": { - "triton": TritonAttBackend, - "fa3": Fa3AttBackend, - "flashinfer": FlashInferAttBackend, + "triton": TritonAttBackend, # the triton backend already supports arbitrary page sizes + "fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend, + "flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend, }, "int4kv": { "triton": Int4kvTritonAttBackend, @@ -41,8 +47,8 @@ mla_data_type_to_backend = { "None": { "triton": MlaTritonAttBackend, - "fa3": MlaFa3AttBackend, - "flashinfer": MlaFlashInferAttBackend, + "fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend, + "flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend, }, } diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..9568e4a892 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -66,7 +66,7 @@ def prefill_att( alloc_func=torch.empty, ) -> torch.Tensor: assert att_control.use_alibi is False - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, @@ -74,7 +74,7 @@ alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..37478be76f 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -99,14 +99,14 @@ def prefill_att( and att_control.use_sliding_window is False and att_control.use_att_sink is False ) - return self._nomarl_prefill_att( + return self._normal_prefill_att( q=q, k=k, v=v, alloc_func=alloc_func, ) - def _nomarl_prefill_att( + def _normal_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing diff --git a/lightllm/common/basemodel/attention/paged_fa3/__init__.py b/lightllm/common/basemodel/attention/paged_fa3/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git
a/lightllm/common/basemodel/attention/paged_fa3/fp.py b/lightllm/common/basemodel/attention/paged_fa3/fp.py new file mode 100644 index 0000000000..5c01538c42 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/fp.py @@ -0,0 +1,188 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + self.page_table = torch.empty( + (self.infer_state.batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + page_table_copy( + page_table=self.page_table, + req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) + + +@dataclasses.dataclass +class 
PagedFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert att_control.use_alibi is False + return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty): + if att_control.use_sliding_window: + window_size = att_control.sliding_window + else: + window_size = (-1, -1) + + if att_control.use_att_sink: + sink_weight = att_control.sink_weight + else: + sink_weight = None + + sm_scale = 1.0 / (q.shape[-1] ** 0.5) + return flash_attn_with_kvcache( + q=q, + k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + sinks=sink_weight, + ) diff --git a/lightllm/common/basemodel/attention/paged_fa3/mla.py 
b/lightllm/common/basemodel/attention/paged_fa3/mla.py new file mode 100644 index 0000000000..2e33c05409 --- /dev/null +++ b/lightllm/common/basemodel/attention/paged_fa3/mla.py @@ -0,0 +1,174 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.sgl_utils import flash_attn_with_kvcache, flash_attn_varlen_func +from lightllm.utils.envs_utils import get_env_start_args, get_page_size +from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy +from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor + + +class PagedMlaFa3AttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + super().__init__(model=model) + self.page_size = page_size or get_page_size() + self.get_page_table_buffer() + + def get_page_table_buffer(self): + model = self.model + if not hasattr(self, "_shared_page_table_buffer"): + shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size) + self._shared_page_table_buffer = [ + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()), + ] + return self._shared_page_table_buffer + + def create_att_prefill_state(self, infer_state): + return PagedMlaFa3PrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFa3DecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFa3PrefillAttState(BasePrefillAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + + def init_state(self): + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + return self._mla_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl, alloc_func=torch.empty + ): + k_nope, k_rope = k + q_head_num = q.shape[1] + k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3 + assert att_control.mla_prefill + return flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + max_seqlen_k=self.infer_state.max_kv_seq_len, + softmax_scale=att_control.mla_prefill_dict["softmax_scale"], + causal=True, + return_softmax_lse=False, + ) + + +@dataclasses.dataclass +class PagedMlaFa3DecodeAttState(BaseDecodeAttState): + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None + page_table: torch.Tensor = None + b_att_seq_len: torch.Tensor = None + decode_max_q_seq_len: int = None + + def init_state(self): + args_mtp_step = get_env_start_args().mtp_step + if args_mtp_step > 0: + mtp_size = args_mtp_step + 1 + b_q_seq_len = torch.full( + (self.infer_state.b_seq_len.shape[0] // mtp_size,), + fill_value=mtp_size, + dtype=torch.int32, + 
device=self.infer_state.b_seq_len.device, + ) + b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size] + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + self.cu_seqlens_q = b1_cu_q_seq_len.int() + self.cu_seqlens_k = b1_cu_kv_seq_len.int() + else: + self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int() + self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int() + + att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1) + assert self.infer_state.batch_size % (args_mtp_step + 1) == 0 + model = self.backend.model + table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + page_buffer = self.backend.get_page_table_buffer() + shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size) + self.page_table = page_buffer[self.infer_state.microbatch_index][ + : att_batch_size * shared_table_len + ].reshape(att_batch_size, shared_table_len) + else: + self.page_table = torch.empty( + (att_batch_size, table_len), + dtype=torch.int32, + device=self.infer_state.input_ids.device, + ) + + if args_mtp_step > 0: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)], + ) + self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous() + self.decode_max_q_seq_len = args_mtp_step + 1 + else: + page_table_copy( + page_table=self.page_table[:, :table_len], + req_to_token_indexs=model.req_manager.req_to_token_indexs, + b_req_idx=self.infer_state.b_req_idx, + ) + self.b_att_seq_len = self.infer_state.b_seq_len + self.decode_max_q_seq_len = 1 + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + return self._mla_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) + + def _mla_decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl, alloc_func=torch.empty + ): + q_nope, q_rope = q + qk_rope_head_dim = 64 + kv_lora_rank = k.shape[-1] - qk_rope_head_dim + return flash_attn_with_kvcache( + q=q_rope, + k_cache=k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + v_cache=k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, kv_lora_rank), + qv=q_nope, + page_table=self.page_table, + cache_seqlens=self.b_att_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.decode_max_q_seq_len, + softmax_scale=att_control.mla_decode_dict["softmax_scale"], + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + ) diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py b/lightllm/common/basemodel/attention/paged_flashinfer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/fp.py b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py new file mode 100644 index 0000000000..b1807ca30b --- /dev/null +++ 
b/lightllm/common/basemodel/attention/paged_flashinfer/fp.py @@ -0,0 +1,193 @@ +import dataclasses +import torch +import triton +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + tp_world_size = get_dp_world_size() + self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size + self.tp_kv_head_num = max(model.config["num_key_value_heads"] // tp_world_size, 1) + head_dim = model.config["hidden_size"] // model.config["num_attention_heads"] + self.head_dim = model.config.get("head_dim", head_dim) + self.workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + + def create_att_prefill_state(self, infer_state): + return PagedFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedFlashInferAttBackend = self.backend + import flashinfer + + batch_size = self.infer_state.batch_size + device = self.infer_state.input_ids.device + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + kv_indices = torch.empty( + batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size), + dtype=torch.int32, + device=device, + ) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + kv_indices, + ) + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + qo_indptr_buf=q_starts, + paged_kv_indptr_buf=kv_starts, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len, + ) + self.prefill_wrapper.plan( + q_starts, + kv_starts, + kv_indices, + kv_last_page_len, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + ) + + def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + 
and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + self.prefill_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor + + +@dataclasses.dataclass +class PagedFlashInferDecodeAttState(BaseDecodeAttState): + kv_last_page_len_buffer: torch.Tensor = None + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedFlashInferAttBackend = self.backend + device = self.infer_state.input_ids.device + model = self.backend.model + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_last_page_len_buffer = self.infer_state.b_seq_len - (b_page_len - 1) * self.backend.page_size + buffer_len = self.infer_state.batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if ( + self.infer_state.batch_size <= model.graph_max_batch_size + and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch + ): + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + self.kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + self.backend.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=self.kv_starts, + paged_kv_indices_buffer=self.kv_indices, + paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + ) + self.decode_wrapper.plan( + self.kv_starts, + self.kv_indices, + self.kv_last_page_len_buffer, + self.backend.tp_q_head_num, + self.backend.tp_kv_head_num, + self.backend.head_dim, + self.backend.page_size, + q_data_type=self.backend.q_data_type, + kv_data_type=self.backend.kv_data_type, + non_blocking=True, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.kv_starts, + new_state.kv_indices, + new_state.kv_last_page_len_buffer, + new_state.backend.tp_q_head_num, + new_state.backend.tp_kv_head_num, + new_state.backend.head_dim, + new_state.backend.page_size, + q_data_type=new_state.backend.q_data_type, + kv_data_type=new_state.backend.kv_data_type, + non_blocking=True, + ) + + def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + o_tensor = alloc_func(q.shape, q.dtype) + self.decode_wrapper.run( + q, + ( + k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]), + v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]), + ), + out=o_tensor, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/paged_flashinfer/mla.py b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py new file mode 100644 index 0000000000..c9ea38052f --- /dev/null +++ 
b/lightllm/common/basemodel/attention/paged_flashinfer/mla.py @@ -0,0 +1,184 @@ +import dataclasses +import torch +import triton +from typing import Tuple +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id +from ...triton_kernel.repack_kv_index import paged_repack_kv_index +from lightllm.utils.envs_utils import get_page_size +from ..flashinfer.env_utils import set_flashinfer_envs + + +class PagedMlaFlashInferAttBackend(BaseAttBackend): + def __init__(self, model, page_size=None): + set_flashinfer_envs() + super().__init__(model=model) + self.page_size = page_size or get_page_size() + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.v_head_dim = model.v_head_dim + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + buffer_len = model.graph_max_batch_size * triton.cdiv(self.max_seq_length, self.page_size) + self.kv_indices_buffer = [ + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + torch.empty(buffer_len, dtype=torch.int32, device=get_current_device_id()), + ] + + from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale + + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def create_att_prefill_state(self, infer_state): + return PagedMlaFlashInferPrefillAttState(backend=self, infer_state=infer_state) + + def create_att_decode_state(self, infer_state): + return PagedMlaFlashInferDecodeAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class PagedMlaFlashInferPrefillAttState(BasePrefillAttState): + prefill_wrapper: object = None + + def init_state(self): + self.backend: PagedMlaFlashInferAttBackend = self.backend + import flashinfer + + q_starts = self.infer_state.b1_cu_q_seq_len.int() + kv_starts = self.infer_state.b1_cu_kv_seq_len.int() + if self.prefill_wrapper is None: + self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + self.backend.workspace_buffer, "NHD" + ) + self.prefill_wrapper.plan( + qo_indptr=q_starts, + kv_indptr=kv_starts, + num_qo_heads=self.backend.tp_q_head_num, + num_kv_heads=self.backend.tp_q_head_num, + head_dim_qk=self.backend.qk_nope_head_dim + self.backend.qk_rope_head_dim, + head_dim_vo=self.backend.v_head_dim, + q_data_type=self.backend.q_data_type, + causal=True, + sm_scale=self.backend.softmax_scale, + ) + + def prefill_att( + self, q, k: Tuple[torch.Tensor, torch.Tensor], v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + k_nope, k_rope = k + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") + q_head_num = q.shape[1] + k = 
torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) + self.prefill_wrapper.run(q, k, v, out=o_tensor) + return o_tensor + + +@dataclasses.dataclass +class PagedMlaFlashInferDecodeAttState(BaseDecodeAttState): + kv_indices: torch.Tensor = None + kv_starts: torch.Tensor = None + decode_wrapper: object = None + + def init_state(self): + import flashinfer + + self.backend: PagedMlaFlashInferAttBackend = self.backend + model = self.backend.model + device = self.infer_state.input_ids.device + batch_size = self.infer_state.batch_size + self.kv_starts = self.infer_state.b1_cu_kv_seq_len + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + buffer_len = batch_size * triton.cdiv(self.backend.max_seq_length, self.backend.page_size) + if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: + self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][:buffer_len] + else: + self.kv_indices = torch.empty(buffer_len, dtype=torch.int32, device=device) + + b_page_len = triton.cdiv(self.infer_state.b_seq_len, self.backend.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) + paged_repack_kv_index( + self.infer_state.req_manager.req_to_token_indexs, + self.infer_state.b_req_idx, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size), + self.kv_indices, + ) + self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + self.backend.workspace_buffer, + use_cuda_graph=True, + qo_indptr=self.q_indptr, + kv_indices=self.kv_indices, + kv_indptr=self.kv_starts, + kv_len_arr=self.infer_state.b_seq_len, + ) + self.decode_wrapper.plan( + self.q_indptr, + self.kv_starts, + self.kv_indices, + self.infer_state.b_seq_len, + self.backend.tp_q_head_num, + self.backend.kv_lora_rank, + self.backend.qk_rope_head_dim, + self.backend.page_size, + False, + self.backend.softmax_scale, + self.backend.q_data_type, + self.backend.kv_data_type, + ) + + def copy_for_decode_cuda_graph(self, new_state): + super().copy_for_decode_cuda_graph(new_state) + self.decode_wrapper.plan( + new_state.q_indptr, + new_state.kv_starts, + new_state.kv_indices, + new_state.infer_state.b_seq_len, + new_state.backend.tp_q_head_num, + new_state.backend.kv_lora_rank, + new_state.backend.qk_rope_head_dim, + new_state.backend.page_size, + False, + new_state.backend.softmax_scale, + new_state.backend.q_data_type, + new_state.backend.kv_data_type, + ) + + def decode_att( + self, q: Tuple[torch.Tensor, torch.Tensor], k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty + ): + assert ( + att_control.use_alibi is False + and att_control.use_sliding_window is False + and att_control.use_att_sink is False + ) + assert v is None + q_nope, q_rope = q + qk_rope_head_dim = 64 + o_tensor = alloc_func(q_nope.shape, dtype=q_nope.dtype, device=q_nope.device) + self.decode_wrapper.run( + q_nope, + q_rope, + k[:, :, :-qk_rope_head_dim].view(-1, self.backend.page_size, 1, k.shape[-1] - qk_rope_head_dim), + k[:, :, -qk_rope_head_dim:].view(-1, self.backend.page_size, 1, qk_rope_head_dim), + out=o_tensor, + return_lse=False, + ) + return o_tensor diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3b..23bf245af5 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -30,7 +30,7 @@ def prefill_att( assert att_control.tp_alibi 
is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._normal_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,7 +59,7 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + def _normal_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd out = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/triton_kernel/fa3_utils.py b/lightllm/common/basemodel/triton_kernel/fa3_utils.py index 0a524b63b6..f9d1c9e9c6 100644 --- a/lightllm/common/basemodel/triton_kernel/fa3_utils.py +++ b/lightllm/common/basemodel/triton_kernel/fa3_utils.py @@ -1,5 +1,6 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -37,6 +38,13 @@ def page_table_copy( assert page_table.dim() == 2, "page_table should be 2D" assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D" + page_size = get_page_size() + if page_size > 1: + max_seq_len_k = page_table.shape[1] * page_size + sampled = req_to_token_indexs[b_req_idx, :max_seq_len_k:page_size] + page_table.copy_(sampled // page_size) + return + max_seq_len_k = page_table.shape[1] batch_size = page_table.size(0) BLOCK_SIZE = 128 diff --git a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py index e86d2e819e..d50a0a230b 100644 --- a/lightllm/common/basemodel/triton_kernel/repack_kv_index.py +++ b/lightllm/common/basemodel/triton_kernel/repack_kv_index.py @@ -2,6 +2,7 @@ import triton import triton.language as tl +from lightllm.utils.envs_utils import get_page_size @triton.jit @@ -33,6 +34,40 @@ def _fwd_kernel_repack_kv_index( return +@triton.jit +def _fwd_kernel_repack_page_kv_index_from_tokens( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + token_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = (start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)) * page_size + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) * page_size + token_data = tl.load( + req_to_token_indexs + token_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + page_data = token_data // page_size + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, page_data, mask=offs_seq < block_end_loc) + return + + @torch.no_grad() def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): batch_size = req_index.shape[0] @@ -58,6 +93,34 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv return +@torch.no_grad() +def paged_repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): + page_size = get_page_size() + assert page_size > 1 + batch_size = 
req_index.shape[0] + # flashinfer requires out_kv_index to be zeroed before use + out_kv_index.zero_() + BLOCK = 64 + grid = ( + batch_size, + triton.cdiv(max_seq_len, BLOCK), + ) + + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + kv_index, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + kv_index.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + return + + def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): output[start : start + sl] = req_to_token_indexs[b][:sl] diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index 3d93e1b070..f5eccf2670 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -1,6 +1,7 @@ import torch import os import torch.distributed as dist +import triton from lightllm.server.pd_io_struct import KVMoveTask from .mem_manager import MemoryManager from typing import List, Union, Any @@ -9,6 +10,7 @@ from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -45,7 +47,9 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 1203cbdec7..4988468b34 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +import triton from typing import List, Union, Tuple, Any from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp from lightllm.server.pd_io_struct import KVMoveTask @@ -11,7 +12,7 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads @@ -35,6 +36,9 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.dtype = dtype # profile the max total token num if the size is None self.profile_size(mem_fraction) + page_size = get_page_size() + if page_size > 1: + self.size = (self.size // page_size) * page_size 
self.mem_state = torch.arange( 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True @@ -108,7 +112,9 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # allocated; it is not actually managed internally either. This token is reserved for special run modes (e.g. multi-dp with # overlap microbatch) where some requests are padded so that inference can still run normally. Its index value is size, stored in the HOLD_TOKEN_MEMINDEX # member variable, and it plays a role similar to HOLD_REQUEST_ID in req_manager. - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, 2 * head_num, head_dim), dtype=dtype, device="cuda") def alloc_kv_move_buffer(self, max_req_total_len): """ diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 33bdca4475..93364f263e 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,11 +1,12 @@ import torch import collections +import triton from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager @@ -71,13 +72,32 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana def alloc(self): return self.req_list.alloc() + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + def alloc_mem_indices( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + return self.mem_manager.alloc(need_size) + def free(self, free_req_indexes: List[int], free_token_index): for req_index in free_req_indexes: self.req_list.free(req_index) if self.req_list.is_all_free(): logger.debug(f"freed all request size {self.req_list.can_alloc_size}") - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) def free_req(self, free_req_index: int): self.req_list.free(free_req_index) @@ -86,13 +106,73 @@ def free_token(self, free_token_index): - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) return def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def _expand_to_page_mem_indices(self, free_token_index): + page_size = get_page_size() + if page_size > 1: + if isinstance(free_token_index, list):
free_token_index = torch.tensor(free_token_index, dtype=torch.int32) + base_indices = free_token_index[free_token_index % page_size == 0] + if len(base_indices) == 0: + return free_token_index + page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device) + return (base_indices[:, None] + page_offsets[None, :]).reshape(-1) + + return free_token_index + + def _expand_by_page_size(self, b_token_len, page_size): + b_page_len = triton.cdiv(b_token_len, page_size) + need_pages_num = int(b_page_len.sum().item()) + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, p_token_len + + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() + if b_ready_cache_len is not None: + b_ready_cache_len = b_ready_cache_len.cpu() + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] + + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = int(need_new_page_mask.sum().item()) + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size) + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] + mask = ~need_new_page_mask + if mask.any(): + token_idxs[mask] = b_last_mem_index[mask] + 1 + return token_idxs + + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = triton.cdiv(need_tokens_array, page_size) + need_new_pages = need_pages_array.sum() + else: + need_new_pages = ((b_seq_len - 1) % page_size == 0).sum() + return need_new_pages * page_size + class ReqSamplingParamsManager: """ diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 111def60c2..3671186e41 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -10,7 +10,7 @@ from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name -from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive +from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive, get_page_size from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active @@ -134,10 +134,26 @@ def normal_or_p_d_start(args): if args.mtp_mode is not None: assert args.mtp_draft_model_dir is not None assert args.mtp_step > 0 + assert get_page_size() == 1, "page_size > 1 is not supported with MTP, please set PAGE_SIZE=1" else: assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # page_size > 1 
compatibility check + if get_page_size() > 1: + assert args.run_mode not in ( + "prefill", + "decode", + ), "page_size > 1 is not supported with RPyC PD split mode, please set PAGE_SIZE=1" + assert args.run_mode not in ( + "nixl_prefill", + "nixl_decode", + ), "page_size > 1 is not supported with NIXL PD split mode, please set PAGE_SIZE=1" + assert ( + not args.enable_dp_prefill_balance + ), "page_size > 1 is not supported with DP prefill balance, please set PAGE_SIZE=1" + assert not args.enable_cpu_cache, "page_size > 1 is not supported with CPU cache, please set PAGE_SIZE=1" + # check that the number of GPUs is sufficient if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 0000000000..72fab6e274 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,538 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +import collections +from typing import Tuple, Dict, Set, List, Optional, Union +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) + + def _compute_key(self, tokens: torch.Tensor): + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else page_tokens.cpu().numpy().tobytes() + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + split_parent_node.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + split_parent_node.children[self._compute_key(self.token_id_key[prefix_len:])] = self + split_parent_node.ref_counter = self.ref_counter + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = self.token_id_key[prefix_len:] + self.token_mem_index_value = self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = 
token_id_key + child.token_mem_index_value = token_mem_index_value + child_key = child._compute_key(child.token_id_key) + assert child_key not in self.children.keys() + self.children[child_key] = child + child.parent = self + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[child_node._compute_key(child_node.token_id_key)] + child_node.parent = None + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return len(self.children) == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + from lightllm.common.kv_cache_mem_manager import MemoryManager + + self.mem_manager: MemoryManager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _align_prefix_len(self, prefix_len: int) -> int: + if self.page_size <= 1: + return prefix_len + if prefix_len % self.page_size == 0: + return prefix_len + if self._page_size_is_power_of_2: + return prefix_len & ~self._page_size_mask + return (prefix_len // self.page_size) * self.page_size + + def _get_page_aligned_key(self, key, value=None, free_truncated=False): + aligned_len = len(key) + if aligned_len == 0: + return None, None + if self.page_size > 1 and aligned_len % self.page_size != 0: + aligned_len = self._align_prefix_len(aligned_len) + if free_truncated and aligned_len < len(key) and self.mem_manager is not None and value is not None: + truncated_value = value[aligned_len:] + if len(truncated_value) > 0: + base = truncated_value[0] - truncated_value[0] % self.page_size + full_page = torch.arange( + base, base + self.page_size, dtype=truncated_value.dtype, device=truncated_value.device + ) + self.mem_manager.free(full_page) + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + if value is None: + value = key + + assert len(key) == len(value) + key, value = self._get_page_aligned_key(key, value, free_truncated=True) + if key 
is None: + return 0, None + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key, value)) + + ans_prefix_len = 0 + ans_node = None + + while len(handle_stack) != 0: + node, key, value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) + if len(ans_tuple) == 4: + (_prefix_len, new_node, new_key, new_value) = ans_tuple + ans_prefix_len += _prefix_len + handle_stack.append((new_node, new_key, new_value)) + else: + _prefix_len, ans_node = ans_tuple + ans_prefix_len += _prefix_len + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + assert ans_node is not None + + return ans_prefix_len, ans_node + + def _insert_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, value: torch.Tensor + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + child_key = node._compute_key(key) + if child_key in node.children.keys(): + child: TreeNode = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + if prefix_len == len(key): + if prefix_len == len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, child + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if child.is_leaf(): + self.evict_tree_set.add(child) + + return prefix_len, split_parent_node + else: + assert False, "can not run to here" + + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, new_node + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + ans_value_list = [] + tree_node = 
self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + if update_refs: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key)) + + ans_node = None + + while len(handle_stack) != 0: + node, key = handle_stack.popleft() + ans_tuple = self._match_prefix_helper_no_recursion( + node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs + ) + if isinstance(ans_tuple, tuple): + new_node, new_key = ans_tuple + handle_stack.append((new_node, new_key)) + else: + ans_node = ans_tuple + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + return ans_node + + def _match_prefix_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + if len(key) == 0: + return node + + child_key = node._compute_key(key) + if child_key not in node.children.keys(): + return node + else: + child = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + return node + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return (child, key[prefix_len:]) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: + parent_node = 
+    def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
+        parent_node = child_node.parent
+        if (
+            parent_node is None
+            or parent_node == self.root_node
+            or parent_node.ref_counter != 0
+            or len(parent_node.children) != 1
+            or child_node.ref_counter != 0
+        ):
+            return None
+
+        if child_node.is_leaf():
+            self.evict_tree_set.discard(child_node)
+
+        child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key])
+        child_node.token_mem_index_value = torch.cat(
+            [parent_node.token_mem_index_value, child_node.token_mem_index_value]
+        )
+        child_node.node_value_len = len(child_node.token_mem_index_value)
+        child_node.time_id = max(parent_node.time_id, child_node.time_id)
+
+        grandparent_node = parent_node.parent
+        key_in_grandparent = grandparent_node._compute_key(parent_node.token_id_key)
+        grandparent_node.children[key_in_grandparent] = child_node
+        child_node.parent = grandparent_node
+
+        parent_node.parent = None
+
+        if child_node.is_leaf():
+            self.evict_tree_set.add(child_node)
+
+        return child_node
+
+    def merge_unreferenced_nodes(self):
+        worklist = collections.deque(
+            [
+                node
+                for node in self.evict_tree_set
+                if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node
+            ]
+        )
+
+        while worklist:
+            node = worklist.popleft()
+            if node.parent is None:
+                continue
+            merged_node = self._try_merge(node)
+            if merged_node:
+                worklist.append(merged_node)
+
+    def assert_leafs_is_right(self):
+        for node in self.evict_tree_set:
+            if node.is_leaf() and node.ref_counter == 0:
+                a = node.token_mem_index_value.cuda()
+                assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a)
+
+    def clear_tree_nodes(self):
+        while True:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            if node != self.root_node:
+                parent_node: TreeNode = node.parent
+                parent_node.remove_child(node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+            else:
+                break
+
+        self.tree_total_tokens_num.arr[0] = 0
+        self.refed_tokens_num.arr[0] = 0
+        return
+
+    def dec_node_ref_counter(self, node: TreeNode):
+        if node is None:
+            return
+        old_node = node
+        if old_node.is_leaf():
+            self.evict_tree_set.discard(old_node)
+
+        while node is not None:
+            if node.ref_counter == 1:
+                self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value)
+            node.ref_counter -= 1
+            node = node.parent
+
+        if old_node.is_leaf():
+            self.evict_tree_set.add(old_node)
+        return
+
+    def add_node_ref_counter(self, node: TreeNode):
+        if node is None:
+            return
+        old_node = node
+        if old_node.is_leaf():
+            self.evict_tree_set.discard(old_node)
+
+        while node is not None:
+            if node.ref_counter == 0:
+                self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
+            node.ref_counter += 1
+            node = node.parent
+
+        if old_node.is_leaf():
+            self.evict_tree_set.add(old_node)
+        return
+
+    def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]:
+        if node is None:
+            return None
+
+        ans_list = []
+        while node is not None:
+            ans_list.append(node.token_mem_index_value)
+            node = node.parent
+
+        ans_list.reverse()
+        return torch.concat(ans_list, dim=0)
+
+    def get_refed_tokens_num(self):
+        return self.refed_tokens_num.arr[0]
+
+    def get_tree_total_tokens_num(self):
+        return self.tree_total_tokens_num.arr[0]
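The two ref-counter walks above pin or unpin an entire root path; `match_prefix(..., update_refs=True)` does the increments itself, so callers only owe the matching decrement when a request finishes. A sketch of the assumed lifecycle (`cache` is again a hypothetical instance):

```python
node, hit_len, value = cache.match_prefix(token_ids, update_refs=True)
if node is not None:
    try:
        pass  # reuse `value` (the cached KV mem indices) for this request
    finally:
        cache.dec_node_ref_counter(node)  # unpin the whole path back to the root
```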
+    def print_self(self, indent=0):
+        self._print_helper(self.root_node, indent)
+
+    def _print_helper(self, node: TreeNode, indent):
+        print(
+            " " * indent,
+            f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \
+            time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \
+            node_value_len: {node.node_value_len}",
+        )
+        for _, child in node.children.items():
+            self._print_helper(child, indent=indent + 2)
+        return
+
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
+        assert self.mem_manager is not None
+        if need_token_num > self.mem_manager.can_use_mem_size:
+            need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            self.evict(need_evict_token_num, release_mem)
+            mem_index = torch.concat(release_mems)
+            self.mem_manager.free(mem_index)
+        return
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 1b4a1ca5cb..8fd8a13d6b 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -329,6 +329,7 @@ def __init__(
         self.shm_index = shm_index
         self.multimodal_params = multimodal_params
         self.vocab_size = vocab_size
+        self.last_kv_mem_index = -1

         # the request needs to be paused
         self.wait_pause = False
@@ -415,6 +416,7 @@ def _match_radix_cache(self):
             # copying from CPU to GPU is a stream-blocking operation
             g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor
             self.cur_kv_len = int(ready_cache_len)  # serialization issue: may be numpy.int64, convert with int(*)
+            self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1
             self.shm_req.prompt_cache_len = self.cur_kv_len  # record the prompt cache hit length
             self.shm_req.shm_cur_kv_len = self.cur_kv_len
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 8b085c45ed..02c71e9624 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -10,6 +10,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.models import get_model
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
+from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache
 from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
@@ -28,6 +29,7 @@
 from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
 from lightllm.utils.envs_utils import (
     get_env_start_args,
+    get_page_size,
     enable_radix_tree_timer_merge,
     get_radix_tree_merge_update_delta,
 )
@@ -172,8 +174,9 @@ def init_model(self, kvargs):
         self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
         self.model: TpPartBaseModel = self.model  # for easy typing
         set_random_seed(2147483647)
+        radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache
         self.radix_cache = (
-            RadixCache(
+            radix_cache_class(
                 get_unique_server_name(),
                 self.model.mem_manager.size,
                 self.rank_in_node,
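`last_kv_mem_index` exists because, with `PAGE_SIZE > 1`, a request's final page is usually only partially filled; remembering the last occupied slot lets decode-time allocation keep writing into that page instead of opening a new one. An illustrative calculation (numbers invented, logic assumed, not taken from the patch):

```python
page_size = 4
last_kv_mem_index = 9          # request currently ends at slot 9 (page 2, offset 1)
if (last_kv_mem_index + 1) % page_size != 0:
    next_index = last_kv_mem_index + 1   # slot 10: the current page still has room
else:
    next_index = None                    # page full: a fresh page must be allocated
```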
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
index 03ac4cfb05..f4fa8ae54c 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py
@@ -86,8 +86,18 @@ def padded_prepare_prefill_inputs(
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-        mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num)
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(
+            input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)]
+        )
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
+        mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
+            input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)]
+        )
+        b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill(
+            mem_indexes, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)]
+        )
+        for i, req in enumerate(req_objs):
+            req.last_kv_mem_index = b_last_mem_index[i].item()
     g_infer_state_lock.release()

     if padded_req_num > 0:
@@ -140,6 +150,7 @@ def padded_prepare_decode_inputs(
     b_mtp_index = []
     b_seq_len = []
     b_q_seq_len = []
+    b_last_mem_index = []
     args_mtp_step = get_env_start_args().mtp_step
     batch_multimodal_params = []
     for req in req_objs:
@@ -152,6 +163,7 @@
         total_token_num += seq_len
         b_mtp_index.append(0)
         batch_multimodal_params.append(req.multimodal_params)
+        b_last_mem_index.append(req.last_kv_mem_index)
         # process the draft tokens.
         for step in range(req.mtp_step):
             run_reqs.append(req)
@@ -187,13 +199,23 @@
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu")

     # prepare tokens for the dynamic prompt cache
     padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1)
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num)
-        mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num)
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(
+            b_seq_len.shape[0] - padded_mem_indexes_num, b_seq_len[: len(b_last_mem_index)]
+        )
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
+        mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
+            b_seq_len.shape[0] - padded_mem_indexes_num,
+            b_seq_len[: len(b_last_mem_index)],
+            b_last_mem_index=b_last_mem_index,
+        )
+        for i, req in enumerate(req_objs):
+            req.last_kv_mem_index = mem_indexes[i].item()
     g_infer_state_lock.release()

     if padded_mem_indexes_num > 0:
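`calc_real_need_token_num` and `alloc_mem_indices` are new `req_manager` methods introduced elsewhere in this PR and not shown here. The accounting they must perform is at least the following worst-case bound, which also matches the queue-side headroom added below in `chunked_prefill/impl.py` (this sketch is an assumption, not the real implementation):

```python
def worst_case_need(need_num: int, num_reqs: int, page_size: int) -> int:
    # Each request can strand at most page_size - 1 unused slots in its
    # partially filled tail page, so reserve that much extra before allocating.
    return need_num + num_reqs * (page_size - 1) if page_size > 1 else need_num
```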
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
index 4eb8c7e1e6..a915564d78 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
@@ -64,8 +64,16 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) ->
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-        mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(
+            input_ids.shape[0], b_seq_len, b_ready_cache_len
+        )
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
+        mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len)
+        b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill(
+            mem_indexes, b_seq_len, b_ready_cache_len
+        )
+        for i, req in enumerate(req_objs):
+            req.last_kv_mem_index = b_last_mem_index[i].item()
     g_infer_state_lock.release()

     model_input = ModelInput(
@@ -97,6 +105,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     b_mtp_index = []
     b_seq_len = []
     b_q_seq_len = []
+    b_last_mem_index = []
     multimodal_params = []
     for req in req_objs:
         run_reqs.append(req)
@@ -108,6 +117,7 @@
         total_token_num += seq_len
         b_mtp_index.append(0)
         multimodal_params.append(req.multimodal_params)
+        b_last_mem_index.append(req.last_kv_mem_index)
         # process the draft tokens.
         for step in range(req.mtp_step):
             run_reqs.append(req)
@@ -125,6 +135,7 @@
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu")

     if enable_diverse_mode_gqa_decode_fast_kernel():
         b_shared_seq_len, b_mark_shared_group = build_diverse_shared_group_infos(run_reqs=run_reqs)
@@ -135,8 +146,13 @@
     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0])
-        mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0])
+        token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len)
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num)
+        mem_indexes = g_infer_context.req_manager.alloc_mem_indices(
+            b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index
+        )
+        for i, req in enumerate(req_objs):
+            req.last_kv_mem_index = mem_indexes[i].item()
    g_infer_state_lock.release()

     model_input = ModelInput(
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 0d870b55d8..ee49c3e1a2 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -3,6 +3,7 @@
 from ...batch import Batch, Req
 from lightllm.server.router.req_queue.base_queue import BaseQueue
 from lightllm.common.basemodel.infer_lock import g_router_lock
+from lightllm.utils.envs_utils import get_page_size


 class ChunkedPrefillQueue(BaseQueue):
@@ -33,9 +34,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()

         with g_router_lock.obj:
+            page_size = get_page_size()
+            page_remaining = len(self.cache_len_list) * (page_size - 1) if page_size > 1 else 0
             ok_token_num = (
                 need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
-                < self.max_total_tokens
+                < self.max_total_tokens - page_remaining
             )

             ok_req_num = len(self.cache_len_list) <= self.running_max_req_size
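A quick sanity check of the new admission bound with made-up numbers: with three running requests and `PAGE_SIZE=8`, up to `3 * (8 - 1) = 21` slots can sit unused in tail pages, so the effective token budget shrinks by 21:

```python
page_size, running_reqs, max_total_tokens = 8, 3, 1000
page_remaining = running_reqs * (page_size - 1) if page_size > 1 else 0
assert page_remaining == 21
effective_budget = max_total_tokens - page_remaining  # 979: the bound _can_add_new_req checks against
```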
diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 7a7a9be121..a8ef5ee497 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -161,6 +161,11 @@ def get_triton_autotune_level():
     return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))


+@lru_cache(maxsize=None)
+def get_page_size():
+    return int(os.getenv("PAGE_SIZE", 1))
+
+
 g_model_init_done = False
diff --git a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py
index b5184d3caa..0bab0ae540 100644
--- a/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py
+++ b/unit_tests/common/basemodel/triton_kernel/test_repack_kv_index.py
@@ -1,7 +1,8 @@
 import torch
 import pytest
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index
+from lightllm.common.basemodel.triton_kernel.repack_kv_index import repack_kv_index, paged_repack_kv_index
+from lightllm.utils.envs_utils import get_page_size

 logger = init_logger(__name__)

@@ -41,3 +42,49 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output
     repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref)
     repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output)
     assert torch.allclose(output.float(), ref.float())
+
+
+@pytest.mark.parametrize(
+    "batch, max_seq_len, page_size",
+    [
+        (1, 16, 4),
+        (8, 32, 4),
+        (16, 128, 8),
+    ],
+)
+def test_paged_repack_kv_index(batch, max_seq_len, page_size, monkeypatch):
+    def repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, output, page_size):
+        for b, sl, start in zip(b_req_idx, b_page_len, b_start_loc):
+            output[start : start + sl] = req_to_token_indexs[b][: sl * page_size : page_size] // page_size
+
+    BATCH, MAX_SEQ_LEN = batch, max_seq_len
+    max_page_len = (MAX_SEQ_LEN + page_size - 1) // page_size
+    total_token_len = 2 * MAX_SEQ_LEN
+    total_page_len = (total_token_len + page_size - 1) // page_size
+
+    req_to_token_indexs = torch.empty((2 * BATCH, total_token_len), dtype=torch.int32, device="cuda")
+    page_offsets = torch.arange(page_size, dtype=torch.int32, device="cuda")
+    for row in range(2 * BATCH):
+        page_ids = torch.arange(row * total_page_len, (row + 1) * total_page_len, dtype=torch.int32, device="cuda")
+        req_to_token_indexs[row] = (page_ids[:, None] * page_size + page_offsets[None, :]).reshape(-1)[:total_token_len]
+
+    b_req_idx = torch.randperm(BATCH, device="cuda", dtype=torch.int32)
+    b_seq_len = torch.randint(1, MAX_SEQ_LEN + 1, (BATCH,), device="cuda", dtype=torch.int32)
+    b_page_len = (b_seq_len + page_size - 1) // page_size
+    b_start_loc = torch.cat(
+        [torch.zeros((1,), dtype=torch.int32, device="cuda"), b_page_len[:-1].cumsum(dim=0, dtype=torch.int32)]
+    )
+
+    output = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda")
+    ref = torch.zeros((b_page_len.sum(),), dtype=torch.int32, device="cuda")
+
+    monkeypatch.setenv("PAGE_SIZE", str(page_size))
+    get_page_size.cache_clear()
+    try:
+        repack_page_kv_ref(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, ref, page_size)
+        paged_repack_kv_index(req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_page_len, output)
+    finally:
+        monkeypatch.delenv("PAGE_SIZE", raising=False)
+        get_page_size.cache_clear()
+
+    assert torch.equal(output, ref)
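One caveat worth noting for users of `get_page_size()`: the `lru_cache` means the first call freezes the value, so `PAGE_SIZE` must be exported before the server process first reads it, and tests (as above) must clear the cache around `monkeypatch`. A small illustration:

```python
import os
from lightllm.utils.envs_utils import get_page_size

os.environ["PAGE_SIZE"] = "8"
get_page_size.cache_clear()   # drop any value cached by an earlier call
assert get_page_size() == 8
```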