[MTP] refactor MTP pre_process #6358

Open
zhoutianzi666 wants to merge 23 commits into PaddlePaddle:develop from
zhoutianzi666:remove_speculate_get_output_padding_offset

Conversation

Collaborator

@zhoutianzi666 zhoutianzi666 commented Feb 5, 2026

MTP pre-processing refactor

Motivation

  • This PR removes MTP's speculate_get_output_padding_offset custom operator and reuses the non-MTP get_padding_offset instead.
  • It also renames output_padding_offset and output_cum_offsets to the more descriptive batch_id_per_token_output and cu_seqlens_q_output (a short sketch of the new layout follows below).
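
For reference, a minimal NumPy sketch of the renamed output-side index layout. This is illustrative only (it is not the real get_padding_offset kernel); seq_lens_output stands for each batch's output token count.

import numpy as np

def output_side_indices(seq_lens_output):
    # Illustrative only: build the packed output-side indices this PR uses.
    # cu_seqlens_q_output[i]       -> start offset of batch i in the packed output
    # batch_id_per_token_output[t] -> which batch packed output token t belongs to
    seq_lens_output = np.asarray(seq_lens_output, dtype=np.int32)
    cu_seqlens_q_output = np.concatenate([[0], np.cumsum(seq_lens_output)]).astype(np.int32)
    batch_id_per_token_output = np.repeat(
        np.arange(len(seq_lens_output), dtype=np.int32), seq_lens_output
    )
    return batch_id_per_token_output, cu_seqlens_q_output

# seq_lens_output = [2, 1, 3] ->
#   batch_id_per_token_output = [0, 0, 1, 2, 2, 2]
#   cu_seqlens_q_output       = [0, 2, 3, 6]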

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

None

Usage or Command

None

Accuracy Tests

None

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests, or explain in this PR why they are not included.
  • Provide accuracy results.
  • If the current PR targets a release branch, make sure it has already been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.


paddle-bot bot commented Feb 5, 2026

Thanks for your contribution!

@zhoutianzi666 zhoutianzi666 changed the title from "commit" to "rename" Feb 5, 2026
@zhoutianzi666 zhoutianzi666 changed the title from "rename" to "rename output_cum_offsets to cu_seqlens_q" Feb 5, 2026
@zhoutianzi666 zhoutianzi666 changed the title from "rename output_cum_offsets to cu_seqlens_q" to "rename output_cum_offsets to cu_seqlens_q , rename output_padding_offset to batch_id_per_token_output" Feb 5, 2026
@zhoutianzi666 zhoutianzi666 changed the title from "rename output_cum_offsets to cu_seqlens_q , rename output_padding_offset to batch_id_per_token_output" to "output_cum_offsets -> cu_seqlens_q , output_padding_offset -> batch_id_per_token_output" Feb 5, 2026
K11OntheBoat added 2 commits February 5, 2026 19:29
…g_offset' into remove_speculate_get_output_padding_offset

CLAassistant commented Feb 5, 2026

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
2 out of 3 committers have signed the CLA.

✅ EmmonsCurse
✅ zhoutianzi666
❌ K11OntheBoat


K11OntheBoat seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.


codecov-commenter commented Feb 5, 2026

Codecov Report

❌ Patch coverage is 41.66667% with 14 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@d6b3c72). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/worker/input_batch.py 22.22% 12 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #6358   +/-   ##
==========================================
  Coverage           ?   67.70%           
==========================================
  Files              ?      391           
  Lines              ?    52243           
  Branches           ?     8149           
==========================================
  Hits               ?    35372           
  Misses             ?    14281           
  Partials           ?     2590           
Flag Coverage Δ
GPU 67.70% <41.66%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@zhoutianzi666 zhoutianzi666 changed the title from "output_cum_offsets -> cu_seqlens_q , output_padding_offset -> batch_id_per_token_output" to "[MTP] refactor MTP pre_process" Feb 7, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR refactors the pre-processing and output-side index layout for MTP/speculative decoding: it removes the speculate_get_output_padding_offset custom operator in favor of reusing get_padding_offset, renames the output-side fields to the more semantic batch_id_per_token_output / cu_seqlens_q_output, and updates the CUDA kernels and related unit tests accordingly.

Changes:

  • Remove/deprecate speculate_get_output_padding_offset; the output-side offsets are now computed by get_padding_offset and threaded through the call chain
  • Switch several CUDA custom-operator inputs from output_* to batch_id_per_token_output / cu_seqlens_q_output
  • Update the speculative-decoding test cases and the runner/share_inputs fields

Reviewed changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 11 comments.

Show a summary per file
File Description
tests/spec_decode/test_mtp_proposer.py Update the MTP proposer test input field names to match the new output-side index fields
tests/operators/test_speculate_verify.py Adapt to the speculate_verify parameter rename to cu_seqlens_q_output
tests/operators/test_speculate_get_token_penalty_multi_scores.py Adapt to the new input fields of token penalty multi scores
tests/operators/test_speculate_get_output_padding_offset.py Delete the unit test file for the removed operator
tests/operators/test_rebuild_padding.py Adapt to rebuild_padding's new output-side index inputs (batch_id/cu_seqlens)
tests/operators/test_reasoning_phase_token_constraint.py Replace the old output offset operator with get_padding_offset and adapt to the new inputs
tests/operators/test_draft_model_update.py Replace the output_cum_offsets parameter name with cu_seqlens_q_output
tests/layers/test_speculative_sampler.py Update the speculative sampler's share_inputs construction to use the new fields
fastdeploy/worker/input_batch.py Add/switch the output-side index cache fields in share_inputs (CUDA branch) and handle them in swap
fastdeploy/worker/gpu_model_runner.py Wire the new pre_process return values into the new fields and write them to share_inputs during speculative decoding
fastdeploy/spec_decode/mtp.py Wire the new fields into the MTP proposer CUDA path and select per-platform behavior on the draft_model_update side
fastdeploy/model_executor/pre_and_post_process.py Reuse get_padding_offset for the output-side indices in pre_process; adjust the rebuild_padding interface
fastdeploy/model_executor/layers/sample/sampler.py Use the new fields in the speculative sampler CUDA path (verify/penalty/constraint/top_p_candidates)
fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py Switch the penalty/constraint interface parameters to the new field names
fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py Remove the unused zero_seq_enc_lens_for_decode cache
fastdeploy/model_executor/layers/attention/flash_attn_backend.py Remove the unused zero_seq_enc_lens_for_decode cache
custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu Use batch_id_per_token_output to locate the batch directly in top_p_candidates
custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu Switch the speculate_verify input from output_cum_offsets to cu_seqlens_q_output
custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu Switch the penalty kernel inputs to the new fields (batch_id/cu_seqlens)
custom_ops/gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu Delete the old output offset CUDA implementation file
custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_update.cu Rename the draft_model_update input parameter to cu_seqlens_q_output
custom_ops/gpu_ops/rebuild_padding.cu Add an optional cu_seqlens_q_output input to rebuild_padding and compute with batch_id instead
custom_ops/gpu_ops/reasoning_phase_token_constraint.cu Switch the reasoning constraint inputs to the new fields (batch_id/cu_seqlens)
custom_ops/gpu_ops/cpp_extensions.cc Remove the old operator bindings and update the related function declarations
Comments suppressed due to low confidence (1)

custom_ops/gpu_ops/cpp_extensions.cc:1614

  • After removing the Python binding/implementation of speculate_get_output_padding_offset here, the source file list in the build script must be updated to match; custom_ops/setup_ops.py still lists gpu_ops/speculate_decoding/speculate_get_output_padding_offset.cu in its ROCm sources, and since that file is deleted in this PR the ROCm build will fail outright. It is recommended to remove that source entry (or replace it with the new implementation file).
  m.def("speculate_get_seq_lens_output",
        &SpeculateGetSeqLensOutput,
        "speculate_get_seq_lens_output function");

  m.def("speculate_get_token_penalty_multi_scores",
        &SpecTokenPenaltyMultiScores,
        "speculate_get_token_penalty_multi_scores function");

Comment on lines 226 to 275
specific_platform = current_platform.is_cuda() or current_platform.is_maca() or current_platform.is_iluvatar()
if specific_platform and not speculative_decoding:
    # Note(ZKK): This case's code is very simple!
    ids_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k = get_padding_offset(
        input_ids, seq_lens_this_time, None, None, token_num_cpu
    )
    return (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        None,
    )
# Remove padding
max_len = input_ids.shape[1]
cum_offsets_now = paddle.cumsum(max_len - seq_lens_this_time, dtype="int32")
output_padding_offset = None
output_cum_offsets = None
if speculative_decoding:
    (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, seq_lens_this_time, draft_tokens, seq_lens_encoder, token_num_cpu)

    # compute each batch's output token num
    seq_lens_output = speculate_get_seq_lens_output(
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
    )
    if isinstance(seq_lens_output, list):
        seq_lens_output = seq_lens_output[0]
    output_token_num = paddle.sum(seq_lens_output)
    output_cum_offsets_tmp = paddle.cumsum(max_len - seq_lens_output, dtype="int32")
    output_padding_offset, output_cum_offsets = speculate_get_output_padding_offset(
        output_cum_offsets_tmp,
        output_token_num,

    useless_input_ids = input_ids + 0
    _, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
        useless_input_ids,
        seq_lens_output,
        max_len,
        None,
        None,
        output_token_num.item(),
    )
else:
    token_num = paddle.sum(seq_lens_this_time)
    (
        ids_remove_padding,
        batch_id_per_token,
        cu_seqlens_q,
        cu_seqlens_k,
    ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)

return (
    ids_remove_padding,
    batch_id_per_token,
    cu_seqlens_q,
    cu_seqlens_k,
    output_cum_offsets,
    output_padding_offset,
    cu_seqlens_q_output,
    batch_id_per_token_output,
)
Copilot AI Feb 7, 2026

On non-CUDA/non-maca/non-iluvatar platforms with speculative_decoding=False, pre_process neither takes the early return nor enters the speculative_decoding branch, so ids_remove_padding/batch_id_per_token and friends are returned without ever being defined (a runtime error). It is recommended to add the missing get_padding_offset call/return values for that branch, or to explicitly restrict this implementation to the supported platforms.
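
A schematic sketch of the issue and one possible fix, using dummy stand-ins rather than the real pre_process/get_padding_offset (names and control flow only; the actual signatures are as shown in the diff above):

def pre_process_sketch(platform, speculative_decoding):
    # Dummy stand-in for the custom operator; only the control flow matters here.
    def fake_get_padding_offset():
        return "ids_remove_padding", "batch_id_per_token", "cu_seqlens_q", "cu_seqlens_k"

    if platform in ("cuda", "maca", "iluvatar") and not speculative_decoding:
        return fake_get_padding_offset()
    if speculative_decoding:
        outputs = fake_get_padding_offset()
    else:
        # Without this fallback, `outputs` would be undefined for
        # platform="other" with speculative_decoding=False (the reported bug).
        outputs = fake_get_padding_offset()
    return outputs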

Comment on lines 268 to 275
return (
    ids_remove_padding,
    batch_id_per_token,
    cu_seqlens_q,
    cu_seqlens_k,
    output_cum_offsets,
    output_padding_offset,
    cu_seqlens_q_output,
    batch_id_per_token_output,
)
Copilot AI Feb 7, 2026

The return order/count of pre_process has changed to (.., cu_seqlens_q_output, batch_id_per_token_output), but there are still callers in the repository (e.g. the GCU/Metax runners) that unpack the old 7 return values, which will raise an unpacking error. It is recommended to keep the return structure backward compatible (or update all callers at the same time and branch per platform).
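
One backward-compatible option (a hypothetical sketch, not what this PR does) is to return a named structure so GCU/Metax-style callers are not coupled to the tuple arity:

from typing import Any, NamedTuple, Optional

class PreProcessOutputs(NamedTuple):
    # Hypothetical container: callers read fields by name, so adding the new
    # output-side fields cannot break positional unpacking at old call sites.
    ids_remove_padding: Any
    batch_id_per_token: Any
    cu_seqlens_q: Any
    cu_seqlens_k: Any
    output_cum_offsets: Optional[Any] = None
    output_padding_offset: Optional[Any] = None
    cu_seqlens_q_output: Optional[Any] = None
    batch_id_per_token_output: Optional[Any] = None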

Comment on lines 259 to 266
useless_input_ids = input_ids + 0
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
    useless_input_ids,
    seq_lens_output,
    max_len,
    None,
    None,
    output_token_num.item(),
)
Copilot AI Feb 7, 2026

Constructing useless_input_ids = input_ids + 0 just to compute the output-side offsets adds an extra tensor computation/copy (and possibly an unnecessary kernel launch). If the get_padding_offset kernel only reads input_ids, it can be passed directly (or a lighter placeholder input can be constructed) to avoid the extra overhead.
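
Assuming the kernel only reads input_ids and never writes to it, the call in the excerpt above could pass the tensor directly (a hedged variant of the diff, not a confirmed change):

# Same call as in the excerpt above, with `useless_input_ids = input_ids + 0` dropped.
_, batch_id_per_token_output, cu_seqlens_q_output, _ = get_padding_offset(
    input_ids,
    seq_lens_output,
    max_len,
    None,
    None,
    output_token_num.item(),
)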

swap_data(self.actual_draft_token_num, i1, i2)
swap_data(self.output_cum_offsets, i1, i2)
if current_platform.is_cuda():
    swap_data(self.cu_seqlens_q_output, i1, i2)
Copilot AI Feb 7, 2026

In the speculative_decoding + CUDA branch, swap_states only swaps cu_seqlens_q_output and not batch_id_per_token_output. The two form a matched pair of output-side index data; if they are read between a reorder/condense and the next recomputation they will be inconsistent. It is recommended to swap batch_id_per_token_output as well, or to make it explicit that this buffer is never consumed after the swap.

Suggested change
swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.batch_id_per_token_output, i1, i2)

Comment on lines 135 to 150
@@ -144,8 +145,8 @@ std::vector<paddle::Tensor> rebuild_padding(
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
output_padding_offset.get_ptr()->data<int>(),
max_input_length,
batch_id_per_token_output.get_ptr()->data<int>(),
cu_seqlens_q_output.get_ptr()->data<int>(),
dim_embed,
Copilot AI Feb 7, 2026

When batch_id_per_token_output is non-null, this code unconditionally dereferences cu_seqlens_q_output.get_ptr() (line 149). If the Python side passes only batch_id_per_token_output and not cu_seqlens_q_output, this is a null-pointer/uninitialized access. It is recommended to add a PD_CHECK(cu_seqlens_q_output, ...) on the C++ side or provide a sensible default in the else branch, and to require the pair to be passed together in the operator interface docs/callers.
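
A minimal Python-side sketch of the pairing rule the comment asks for (a hypothetical helper, mirroring what a PD_CHECK would enforce on the C++ side):

def check_output_index_pair(batch_id_per_token_output, cu_seqlens_q_output):
    # Hypothetical guard: the two output-side index tensors must be supplied
    # together (or both omitted) before calling rebuild_padding.
    assert (batch_id_per_token_output is None) == (cu_seqlens_q_output is None), (
        "batch_id_per_token_output and cu_seqlens_q_output must be passed as a pair"
    )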

dtype="int32",
)
if current_platform.is_cuda():
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
Copilot AI Feb 7, 2026

On CUDA, cu_seqlens_q_output as returned by get_padding_offset is one-dimensional [bsz+1], but here it is initialized with shape=[max_num_seqs+1, 1] (two-dimensional). The later copy_ in gpu_model_runner will very likely fail on the shape mismatch. It is recommended to initialize it as one-dimensional shape=[max_num_seqs+1], consistent with the cu_seqlens_q convention.

Suggested change
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32")
self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")

Comment on lines +111 to +114
cu_seqlens_q_output = [0] + paddle.cumsum(share_inputs["seq_lens_this_time"]).numpy().tolist()
share_inputs["cu_seqlens_q_output"] = paddle.to_tensor(cu_seqlens_q_output).cast("int32")
share_inputs["batch_id_per_token_output"] = paddle.arange(max_num_seqs, dtype="int32") * 2

Copilot AI Feb 7, 2026

Here share_inputs["seq_lens_this_time"] is two-dimensional [N,1], so paddle.cumsum(...).numpy().tolist() yields a nested list like [[2],[4],...]; prepending [0] and converting back with to_tensor will easily produce dtype/shape errors. In addition, batch_id_per_token_output needs to be a per-token batch id of length token_num (normally the batch id repeated per seq_len); arange(max_num_seqs)*2 has the wrong length and the wrong semantics. It is recommended to build a one-dimensional cu_seqlens_q_output (length N+1) and a per-token batch_id_per_token_output as the custom operator expects.

Suggested change
cu_seqlens_q_output = [0] + paddle.cumsum(share_inputs["seq_lens_this_time"]).numpy().tolist()
share_inputs["cu_seqlens_q_output"] = paddle.to_tensor(cu_seqlens_q_output).cast("int32")
share_inputs["batch_id_per_token_output"] = paddle.arange(max_num_seqs, dtype="int32") * 2
# seq_lens_this_time is [N, 1], convert to [N] before cumsum to build 1D cu_seqlens_q_output
seq_lens_1d = share_inputs["seq_lens_this_time"].squeeze(1)
cu_seqlens_q_output = paddle.concat(
    [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens_1d)]
).astype("int32")
share_inputs["cu_seqlens_q_output"] = cu_seqlens_q_output
# batch_id_per_token_output should be per-token batch ids with length equal to total token num
batch_ids = paddle.arange(max_num_seqs, dtype="int32")
share_inputs["batch_id_per_token_output"] = paddle.repeat_interleave(batch_ids, seq_lens_1d)

Comment on lines +212 to +226
# prepare batch_id_per_token_output and cu_seqlens_q_output
tokens = [1] * bs
token_num = np.sum(tokens)
output_padding_offset = []
output_cum_offsets = [0]
batch_id_per_token_output = []
cu_seqlens_q_output = [0]
opo_offset = 0
for bid in range(bs):
ts = tokens[bid]
for i in range(ts):
output_padding_offset.append(opo_offset)
batch_id_per_token_output.append(opo_offset)
opo_offset += max_seq_len - ts
output_cum_offsets.append(opo_offset)
output_cum_offsets = output_cum_offsets[:-1]
output_padding_offset = paddle.to_tensor(output_padding_offset, "int32")
output_cum_offsets = paddle.to_tensor(output_cum_offsets, "int32")
cu_seqlens_q_output.append(opo_offset)
cu_seqlens_q_output = cu_seqlens_q_output[:-1]
batch_id_per_token_output = paddle.to_tensor(batch_id_per_token_output, "int32")
cu_seqlens_q_output = paddle.to_tensor(cu_seqlens_q_output, "int32")
Copilot AI Feb 7, 2026

The batch_id_per_token_output generated here actually stores the padding offset (opo_offset), not the per-token batch id, and cu_seqlens_q_output likewise accumulates the (max_seq_len - ts) padding lengths. The new CUDA kernel uses batch_id_per_token_output[token_idx] directly as the batch index and cu_seqlens_q_output[bi] as that batch's starting token id, so both need to be constructed differently: batch_id_per_token_output should hold [0..bs-1] repeated per token, and cu_seqlens_q_output should hold cumulative token counts (length bs+1 or bs, per the interface convention). Otherwise the test is likely to become a false positive that barely exercises the logic.
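
A minimal NumPy sketch of the construction the comment describes; the exact length convention for cu_seqlens_q_output (bs vs. bs + 1) should follow the operator interface:

import numpy as np

bs = 4
tokens = [1, 1, 1, 1]  # one output token per batch, as in this test

# per-token batch id: the batch index repeated once per output token
batch_id_per_token_output = np.repeat(np.arange(bs, dtype=np.int32), tokens)

# cumulative output token count: start offset of each batch in the packed output
cu_seqlens_q_output = np.concatenate([[0], np.cumsum(tokens)]).astype(np.int32)

# tokens = [1, 1, 1, 1] ->
#   batch_id_per_token_output = [0, 1, 2, 3]
#   cu_seqlens_q_output       = [0, 1, 2, 3, 4]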

Comment on lines +277 to 282
# cu_seqlens_q_output = np.zeros_like(seq_lens_this_time)
# cu_seqlens_q_output[1:] = np.cumsum(seq_lens_this_time[:-1])
blank_lengths = max_seq_len - seq_lens_this_time
output_cum_offsets = np.concatenate([[0], np.cumsum(blank_lengths[:-1])])
output_cum_offsets = output_cum_offsets.astype("int32")
cu_seqlens_q_output = np.concatenate([[0], np.cumsum(blank_lengths[:-1])])
cu_seqlens_q_output = cu_seqlens_q_output.astype("int32")
actual_candidate_len = rng.integers(1, max_candidate_len + 1, size=sum_seq_this_time, dtype=np.int32)
Copilot AI Feb 7, 2026

In the kernel/reference implementation, cu_seqlens_q_output is treated as each batch's starting token index in the packed output (equivalent to a prefix sum over seq_lens_this_time), but here it is still accumulated from blank_lengths = max_seq_len - seq_lens_this_time, i.e. the old output_cum_offsets semantics (cumulative padding). This makes the start_token_id computation wrong. It is recommended to build cu_seqlens_q_output from a prefix sum of seq_lens_this_time (or reuse the output of get_padding_offset directly).
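
A short NumPy sketch of the prefix sum the comment asks for, assuming seq_lens_this_time is a 1-D array as in this test (length-bs convention, matching the commented-out lines above):

import numpy as np

seq_lens_this_time = np.array([2, 1, 3], dtype=np.int32)  # example lengths

# start token index of each batch in the packed output; the old code accumulated
# max_seq_len - seq_lens_this_time (padding), which is the wrong quantity here
cu_seqlens_q_output = np.concatenate([[0], np.cumsum(seq_lens_this_time[:-1])]).astype(np.int32)
# -> [0, 2, 3]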

Comment on lines +80 to 87
useless_inputs = paddle.zeros([self.bs, self.max_seq_len], dtype="int64")
_, self.output_padding_offset, self.output_cum_offsets, _ = get_padding_offset(
    useless_inputs,
    seq_lens_output,
    self.max_seq_len,
    None,
    None,
    output_token_num.item(),
)
Copilot AI Feb 7, 2026

The actual return values of get_padding_offset here are (x_remove_padding, batch_id_per_token, cu_seqlens_q, cu_seqlens_k), but the variables are still named output_padding_offset/output_cum_offsets, which no longer matches the new semantics in the operator/kernel (batch_id_per_token_output, cu_seqlens_q_output) and will mislead future maintenance. It is recommended to rename the test fields to batch_id_per_token_output and cu_seqlens_q_output and update the subsequent uses accordingly to avoid semantic confusion.

Copilot generated this review using guidance from repository custom instructions.