@@ -229,10 +229,6 @@ def forward_mixed(

sliding_window = layer.sliding_window

- norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
- q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
- k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
-
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
@@ -344,8 +340,8 @@ def forward_mixed(
layer.linear_smooth,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "sinks", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
@@ -400,8 +396,8 @@ def forward_mixed(
layer.linear_smooth,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "sinks", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
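Across the attention backends touched by this diff, the per-layer `qk_norm_before_rope` gate is dropped and the Q/K norm weights are now forwarded to the fused kernel unconditionally. A minimal sketch of the before/after selection logic; `pick_norm_weights_old` and `pick_norm_weights_new` are hypothetical helpers that condense the hunks above, and `layer` stands in for the attention layer object rather than the actual FastDeploy call sites:

```python
def pick_norm_weights_old(layer):
    # Before this change: weights were forwarded to the fused kernel only when
    # the norm was applied after RoPE (i.e. qk_norm_before_rope was False).
    norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
    q_w = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
    k_w = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
    return q_w, k_w


def pick_norm_weights_new(layer):
    # After this change: the pre-RoPE branch is gone, so the weights (or None,
    # if the layer has no QK norm) are always handed to the kernel.
    return getattr(layer, "q_norm_weight", None), getattr(layer, "k_norm_weight", None)
```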
3 changes: 0 additions & 3 deletions fastdeploy/model_executor/layers/attention/attention.py
@@ -59,7 +59,6 @@ def __init__(
linear_smooth: paddle.Tensor = None,
use_neox_rotary_style: bool = False,
use_qk_norm: bool = False,
- qk_norm_before_rope: bool = False,
rms_norm_eps: float = 1e-6,
with_sinks: bool = False,
) -> None:
@@ -77,7 +76,6 @@ def __init__(
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
use_qk_norm (bool, optional): Whether to apply rmsnorm on QA after rope. Defaults to False.
- qk_norm_before_rope (bool, optional): Whether to apply rmsnorm before rope (e.g., Qwen style). Defaults to False. if True, use_qk_norm should also be True.
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.

Raises:
@@ -126,7 +124,6 @@ def __init__(
else:
logger.info(f"Attention is running in cache kv {self.quant_method.cache_quant_config.quant_type} mode")
self.use_qk_norm = use_qk_norm
- self.qk_norm_before_rope = qk_norm_before_rope
self.rms_norm_eps = rms_norm_eps
if self.use_qk_norm:
self.q_norm_key = f"{self.prefix}.q_norm"
50 changes: 4 additions & 46 deletions fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -249,48 +249,6 @@ def forward_mixed(
layer.layer_id + self.start_layer_index,
)

- norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
- q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
- k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
-
- if layer.layer_id == 0:
- get_block_shape_and_split_kv_block(
- forward_meta.seq_lens_encoder,
- forward_meta.seq_lens_decoder,
- forward_meta.seq_lens_this_time,
- forward_meta.decoder_batch_ids,
- forward_meta.decoder_tile_ids_per_batch,
- forward_meta.decoder_num_blocks_cpu,
- forward_meta.decoder_num_blocks_device,
- forward_meta.decoder_chunk_size_device,
- forward_meta.max_len_tensor_cpu,
- forward_meta.encoder_batch_ids,
- forward_meta.encoder_tile_ids_per_batch,
- forward_meta.encoder_num_blocks_x_cpu,
- forward_meta.kv_batch_ids,
- forward_meta.kv_tile_ids_per_batch,
- forward_meta.kv_num_blocks_x_cpu,
- self.encoder_block_shape_q,
- self.decoder_block_shape_q,
- self.group_size,
- self.block_size,
- )
-
- if forward_meta.max_len_tensor_cpu[1].item() > 0:
- (
- metadata.cu_seqlens_k,
- metadata.pre_cache_batch_ids,
- metadata.pre_cache_tile_ids_per_batch,
- metadata.pre_cache_num_blocks_cpu,
- metadata.kv_token_num_cpu,
- ) = pre_cache_len_concat(
- forward_meta.seq_lens_encoder,
- forward_meta.seq_lens_decoder,
- forward_meta.seq_lens_this_time,
- forward_meta.max_len_tensor_cpu[2],
- self.block_size,
- )
-
use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0

if use_fa_do_prefill:
@@ -312,8 +270,8 @@ def forward_mixed(
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
@@ -375,8 +333,8 @@ def forward_mixed(
layer.linear_smooth,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "sinks", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
@@ -181,10 +181,6 @@ def forward_mixed(
):
metadata = forward_meta.attention_metadata

- norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False)
- q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None
- k_norm_weight = getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None
-
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
@@ -256,8 +252,8 @@ def forward_mixed(
forward_meta.pre_cache_batch_ids,
forward_meta.pre_cache_tile_ids_per_batch,
forward_meta.pre_cache_num_blocks_cpu,
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
@@ -324,8 +320,8 @@ def forward_mixed(
layer.linear_smooth,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
- q_norm_weight,
- k_norm_weight,
+ getattr(layer, "q_norm_weight", None),
+ getattr(layer, "k_norm_weight", None),
getattr(layer, "sinks", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype,
89 changes: 1 addition & 88 deletions fastdeploy/model_executor/layers/normalization.py
@@ -20,7 +20,6 @@
import paddle
from paddle import nn

- from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
@@ -29,7 +28,7 @@
from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

from fastdeploy.config import FDConfig
- from fastdeploy.model_executor.ops.triton_ops import _TRITON_AVAILABLE, qk_rmsnorm_fused
+ from fastdeploy.model_executor.forward_meta import ForwardMeta

from .utils import get_tensor

@@ -257,92 +256,6 @@ def forward(
return out, residual_out


- class QKRMSNorm(nn.Layer):
- """
- QK Normalization layer.
- """
-
- def __init__(
- self,
- fd_config: FDConfig,
- head_dim: int,
- q_size: int,
- kv_size: int,
- eps: float = 1e-5,
- prefix: str = "",
- begin_norm_axis: int = 1,
- dtype: str = None,
- ) -> None:
- super().__init__()
- self.fd_config = fd_config
- self.prefix: str = prefix
- self.head_dim: int = head_dim
- self.q_weight_key: Optional[str] = f"{prefix}.q_norm.weight"
- self.k_weight_key: Optional[str] = f"{prefix}.k_norm.weight"
- self.eps: float = eps
- self._norm_weight_dtype = dtype
- if self._norm_weight_dtype is None:
- self._norm_weight_dtype = self._helper.get_default_dtype()
- else:
- assert dtype in [
- "float32",
- "bfloat16",
- "float16",
- ], f"Unsupported dtype: {dtype}. Must be one of: float32, bfloat16, float16"
-
- self.q_size = q_size
- self.kv_size = kv_size
-
- self.q_norm = RMSNorm(
- fd_config,
- hidden_size=self.head_dim,
- eps=fd_config.model_config.rms_norm_eps,
- prefix=f"{prefix}.q_norm",
- begin_norm_axis=begin_norm_axis,
- )
- self.k_norm = RMSNorm(
- fd_config,
- hidden_size=self.head_dim,
- eps=fd_config.model_config.rms_norm_eps,
- prefix=f"{prefix}.k_norm",
- begin_norm_axis=begin_norm_axis,
- )
- self.qk_norm_fused = current_platform.is_cuda() and _TRITON_AVAILABLE
-
- def load_state_dict(self, state_dict):
- self.q_norm.load_state_dict(state_dict)
- self.k_norm.load_state_dict(state_dict)
-
- def forward(
- self,
- qkv_out,
- forward_meta,
- ) -> paddle.Tensor:
- if self.qk_norm_fused and forward_meta.step_use_cudagraph:
- qkv_out = qk_rmsnorm_fused(
- qkv_out,
- self.q_norm.weight,
- self.k_norm.weight,
- self.eps,
- self.q_size,
- self.kv_size,
- self.head_dim,
- )
- else:
- q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size], axis=-1)
-
- q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim])
- q_by_head = self.q_norm(q_by_head)[0]
- q = q_by_head.reshape(q.shape)
-
- k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim])
- k_by_head = self.k_norm(k_by_head)[0]
- k = k_by_head.reshape(k.shape)
-
- qkv_out = paddle.concat([q, k, v], axis=-1)
- return qkv_out
-
-
class LayerNorm(nn.Layer):
"""
Initializes the LayerNormalization layer
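The deleted `QKRMSNorm` layer applied RMSNorm to Q and K per attention head on the packed QKV projection output, dispatching to the Triton `qk_rmsnorm_fused` kernel under CUDA graphs when available and otherwise using an eager split/reshape path. A minimal sketch of that eager per-head path under the same layout assumptions (packed `[..., q_size + 2 * kv_size]` QKV, weights and epsilon supplied by the caller); the RMSNorm is written here with basic Paddle ops instead of the `RMSNorm` layer used in the file:

```python
import paddle


def qk_rmsnorm_eager(qkv, q_norm_weight, k_norm_weight, q_size, kv_size, head_dim, eps=1e-6):
    """Sketch (assumption: simplified) of the unfused per-head QK RMSNorm path."""

    def rms_norm(x, weight):
        # Plain RMSNorm over the last (head_dim) axis, computed in float32.
        x_f32 = x.astype("float32")
        variance = x_f32.pow(2).mean(axis=-1, keepdim=True)
        x_f32 = x_f32 * paddle.rsqrt(variance + eps)
        return (x_f32 * weight.astype("float32")).astype(x.dtype)

    q, k, v = qkv.split([q_size, kv_size, kv_size], axis=-1)

    # Normalize each head of Q independently, then restore the packed shape.
    q_by_head = q.reshape([*q.shape[:-1], q.shape[-1] // head_dim, head_dim])
    q = rms_norm(q_by_head, q_norm_weight).reshape(q.shape)

    # Same treatment for K; V is left untouched.
    k_by_head = k.reshape([*k.shape[:-1], k.shape[-1] // head_dim, head_dim])
    k = rms_norm(k_by_head, k_norm_weight).reshape(k.shape)

    return paddle.concat([q, k, v], axis=-1)
```

With the class removed, QK normalization presumably happens inside the attention kernels via the `q_norm_weight` / `k_norm_weight` arguments that the backends above now pass unconditionally.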
8 changes: 3 additions & 5 deletions fastdeploy/model_executor/ops/triton_ops/__init__.py
@@ -15,12 +15,10 @@
"""

try:
- from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused
from .repetition_early_stop_kernel import repetition_early_stopper_kernel
+ from .wint2_fused_moe import fused_moe_wint2_triton
from .wint2_fused_moe_kernel import moe_wint2_ffn_kernel

- _TRITON_AVAILABLE = True
-
- __all__ = ["moe_wint2_ffn_kernel", "repetition_early_stopper_kernel", "qk_rmsnorm_fused"]
+ __all__ = ["fused_moe_wint2_triton", "moe_wint2_ffn_kernel", "repetition_early_stopper_kernel"]
except:
- _TRITON_AVAILABLE = False
+ pass
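This hunk drops the `qk_rmsnorm_fused` export together with the module-level `_TRITON_AVAILABLE` flag that the deleted `QKRMSNorm` checked before dispatching to the fused kernel. A generic sketch of that optional-import pattern, with hypothetical module and kernel names rather than the FastDeploy ones:

```python
import numpy as np

try:
    # Hypothetical fused kernel; the analogous FastDeploy import was
    # `from .qk_rmsnorm_fused_kernel import qk_rmsnorm_fused`.
    from .fused_kernels import fused_rms_norm  # type: ignore

    _KERNELS_AVAILABLE = True
except ImportError:
    _KERNELS_AVAILABLE = False


def eager_rms_norm(x: np.ndarray, weight: np.ndarray, eps: float) -> np.ndarray:
    # Plain NumPy fallback used when the fused kernel cannot be imported.
    variance = np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True)
    return (x / np.sqrt(variance + eps)) * weight


def rms_norm(x, weight, eps=1e-6):
    # Callers branch on the availability flag, much like the removed
    # `QKRMSNorm.qk_norm_fused` check did with `_TRITON_AVAILABLE`.
    if _KERNELS_AVAILABLE:
        return fused_rms_norm(x, weight, eps)
    return eager_rms_norm(x, weight, eps)
```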