Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import torch
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.backend_validator import validate
from .base_att import BaseAttBackend
Expand All @@ -13,18 +13,24 @@
from .fa3.fp import Fa3AttBackend
from .fa3.fp8 import Fp8Fa3AttBackend
from .fa3.mla import MlaFa3AttBackend
from .paged_fa3.fp import PagedFa3AttBackend
from .paged_fa3.mla import PagedMlaFa3AttBackend
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .paged_flashinfer.fp import PagedFlashInferAttBackend
from .paged_flashinfer.mla import PagedMlaFlashInferAttBackend

logger = init_logger(__name__)

_PAGE_ENABLED = get_page_size() > 1

# Backend class mappings by data type
data_type_to_backend = {
"None": {
"triton": TritonAttBackend,
"fa3": Fa3AttBackend,
"flashinfer": FlashInferAttBackend,
"triton": TritonAttBackend, # triton后端已支持任意page size
"fa3": PagedFa3AttBackend if _PAGE_ENABLED else Fa3AttBackend,
"flashinfer": PagedFlashInferAttBackend if _PAGE_ENABLED else FlashInferAttBackend,
},
"int4kv": {
"triton": Int4kvTritonAttBackend,
Expand All @@ -41,8 +47,8 @@
mla_data_type_to_backend = {
"None": {
"triton": MlaTritonAttBackend,
"fa3": MlaFa3AttBackend,
"flashinfer": MlaFlashInferAttBackend,
"fa3": PagedMlaFa3AttBackend if _PAGE_ENABLED else MlaFa3AttBackend,
"flashinfer": PagedMlaFlashInferAttBackend if _PAGE_ENABLED else MlaFlashInferAttBackend,
},
}

Expand Down
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ def prefill_att(
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.use_alibi is False
return self._nomarl_prefill_att(
return self._normal_prefill_att(
q=q,
k=k,
v=v,
att_control=att_control,
alloc_func=alloc_func,
)

def _nomarl_prefill_att(
def _normal_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
) -> torch.Tensor:
self.backend: Fa3AttBackend = self.backend # for typing
Expand Down
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/flashinfer/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def prefill_att(
and att_control.use_sliding_window is False
and att_control.use_att_sink is False
)
return self._nomarl_prefill_att(
return self._normal_prefill_att(
q=q,
k=k,
v=v,
alloc_func=alloc_func,
)

def _nomarl_prefill_att(
def _normal_prefill_att(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty
) -> torch.Tensor:
self.backend: FlashInferAttBackend = self.backend # for typing
Expand Down
Empty file.
188 changes: 188 additions & 0 deletions lightllm/common/basemodel/attention/paged_fa3/fp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import dataclasses
import torch
import triton
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor


class PagedFa3AttBackend(BaseAttBackend):
    """FlashAttention-3 backend that reads the KV cache through a page table
    (used when the configured page size is > 1)."""

    def __init__(self, model, page_size=None):
        super().__init__(model=model)
        # Fall back to the server-wide configured page size when none is given.
        self.page_size = page_size or get_page_size()
        # Eagerly create the shared page-table buffers so the allocation happens
        # once at startup, not inside the first (possibly CUDA-graph-captured)
        # decode step.
        self.get_page_table_buffer()

    def get_page_table_buffer(self):
        """Return the two shared int32 page-table buffers (one per microbatch),
        lazily allocating them on first call.

        Each buffer is sized for the worst case supported by CUDA graphs:
        graph_max_batch_size rows of ceil(graph_max_len_in_batch / page_size)
        page slots.
        """
        model = self.model
        if not hasattr(self, "_shared_page_table_buffer"):
            shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size)
            device_id = get_current_device_id()
            # Allocate directly on the target device; torch.empty(...).to(dev)
            # would allocate on CPU first and then copy.
            self._shared_page_table_buffer = [
                torch.empty(shared_len, dtype=torch.int32, device=device_id),
                torch.empty(shared_len, dtype=torch.int32, device=device_id),
            ]
        return self._shared_page_table_buffer

    def create_att_prefill_state(self, infer_state):
        return PagedFa3PrefillAttState(backend=self, infer_state=infer_state)

    def create_att_decode_state(self, infer_state):
        return PagedFa3DecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class PagedFa3PrefillAttState(BasePrefillAttState):
    # int32 cumulative query / KV sequence lengths (length batch_size + 1),
    # in the layout flash_attn_with_kvcache expects.
    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    # (batch_size, ceil(max_kv_seq_len / page_size)) int32 table mapping each
    # request's logical page index to a physical page in the KV cache.
    page_table: torch.Tensor = None

    def init_state(self):
        """Prepare per-batch cumulative lengths and the page table for prefill."""
        self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
        self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
        table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
        self.page_table = torch.empty(
            (self.infer_state.batch_size, table_len),
            dtype=torch.int32,
            device=self.infer_state.input_ids.device,
        )
        # Fill the page table from the request manager's token-index map.
        page_table_copy(
            page_table=self.page_table,
            req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs,
            b_req_idx=self.infer_state.b_req_idx,
        )

    def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
        # ALiBi positional bias is not supported by this backend.
        assert att_control.use_alibi is False
        return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

    def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
        """Run causal paged prefill attention via flash_attn_with_kvcache.

        Optionally applies a sliding window and/or attention-sink weights as
        requested by att_control.
        """
        if att_control.use_sliding_window:
            window_size = att_control.sliding_window
        else:
            # (-1, -1) means "no window" for FA3.
            window_size = (-1, -1)

        if att_control.use_att_sink:
            sink_weight = att_control.sink_weight
        else:
            sink_weight = None

        # Standard 1/sqrt(head_dim) softmax scaling.
        sm_scale = 1.0 / (q.shape[-1] ** 0.5)
        # k/v are viewed as (num_pages, page_size, num_heads, head_dim) so the
        # kernel can index them through self.page_table.
        return flash_attn_with_kvcache(
            q=q,
            k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
            v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
            page_table=self.page_table,
            cache_seqlens=self.infer_state.b_seq_len,
            cu_seqlens_q=self.cu_seqlens_q,
            cu_seqlens_k_new=self.cu_seqlens_k,
            max_seqlen_q=self.infer_state.max_q_seq_len,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=None,
            v_descale=None,
            return_softmax_lse=False,
            sinks=sink_weight,
        )


@dataclasses.dataclass
class PagedFa3DecodeAttState(BaseDecodeAttState):
    # int32 cumulative query / KV sequence lengths for the attention batch.
    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    # (att_batch_size, table_len) int32 page table into the paged KV cache.
    page_table: torch.Tensor = None
    # Per-request KV lengths as seen by the attention kernel.
    b_att_seq_len: torch.Tensor = None
    # Max query tokens per request in decode: 1 normally, mtp_step + 1 with MTP.
    decode_max_q_seq_len: int = None

    def init_state(self):
        """Prepare cumulative lengths and the page table for a decode batch.

        With multi-token prediction (mtp_step > 0) each logical request
        occupies mtp_step + 1 consecutive slots in the batch tensors; those
        slots are folded into a single attention row of mtp_step + 1 query
        tokens.
        """
        args_mtp_step = get_env_start_args().mtp_step
        if args_mtp_step > 0:
            mtp_size = args_mtp_step + 1
            # Batch tensors must hold whole MTP groups; otherwise the folded
            # sizes below would be silently wrong.
            assert self.infer_state.b_seq_len.shape[0] % mtp_size == 0
            b_q_seq_len = torch.full(
                (self.infer_state.b_seq_len.shape[0] // mtp_size,),
                fill_value=mtp_size,
                dtype=torch.int32,
                device=self.infer_state.b_seq_len.device,
            )
            # The last slot of each MTP group carries the request's full KV length.
            b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size]
            b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
            self.cu_seqlens_q = b1_cu_q_seq_len.int()
            self.cu_seqlens_k = b1_cu_kv_seq_len.int()
        else:
            self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

        assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
        att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
        model = self.backend.model
        table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
        if (
            self.infer_state.batch_size <= model.graph_max_batch_size
            and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
        ):
            # Batch fits the CUDA-graph envelope: reuse the pre-allocated
            # shared buffer for this microbatch so the address is stable
            # across graph replays.
            page_buffer = self.backend.get_page_table_buffer()
            shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size)
            self.page_table = page_buffer[self.infer_state.microbatch_index][
                : att_batch_size * shared_table_len
            ].reshape(att_batch_size, shared_table_len)
        else:
            self.page_table = torch.empty(
                (att_batch_size, table_len),
                dtype=torch.int32,
                device=self.infer_state.input_ids.device,
            )

        if args_mtp_step > 0:
            # One page-table row per MTP group: use the last slot's request index.
            page_table_copy(
                page_table=self.page_table[:, :table_len],
                req_to_token_indexs=model.req_manager.req_to_token_indexs,
                b_req_idx=self.infer_state.b_req_idx[args_mtp_step :: (args_mtp_step + 1)],
            )
            self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step :: (args_mtp_step + 1)].contiguous()
            self.decode_max_q_seq_len = args_mtp_step + 1
        else:
            page_table_copy(
                page_table=self.page_table[:, :table_len],
                req_to_token_indexs=model.req_manager.req_to_token_indexs,
                b_req_idx=self.infer_state.b_req_idx,
            )
            self.b_att_seq_len = self.infer_state.b_seq_len
            self.decode_max_q_seq_len = 1

    def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
        # ALiBi positional bias is not supported by this backend.
        assert att_control.use_alibi is False
        return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

    def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
        """Run causal paged decode attention via flash_attn_with_kvcache.

        Optionally applies a sliding window and/or attention-sink weights as
        requested by att_control.
        """
        if att_control.use_sliding_window:
            window_size = att_control.sliding_window
        else:
            # (-1, -1) means "no window" for FA3.
            window_size = (-1, -1)

        if att_control.use_att_sink:
            sink_weight = att_control.sink_weight
        else:
            sink_weight = None

        # Standard 1/sqrt(head_dim) softmax scaling.
        sm_scale = 1.0 / (q.shape[-1] ** 0.5)
        # k/v are viewed as (num_pages, page_size, num_heads, head_dim) so the
        # kernel can index them through self.page_table.
        return flash_attn_with_kvcache(
            q=q,
            k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
            v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
            page_table=self.page_table,
            cache_seqlens=self.b_att_seq_len,
            cu_seqlens_q=self.cu_seqlens_q,
            cu_seqlens_k_new=self.cu_seqlens_k,
            max_seqlen_q=self.decode_max_q_seq_len,
            softmax_scale=sm_scale,
            causal=True,
            window_size=window_size,
            softcap=0.0,
            k_descale=None,
            v_descale=None,
            return_softmax_lse=False,
            sinks=sink_weight,
        )
Loading