
Conversation

@yaox12
Member

@yaox12 yaox12 commented Jan 29, 2026

Description

  • Added a new score function `sqrtsoftplus`
  • Added tests
  • Switched to FP32 math instead of FP64 for better performance on SM103. Precision principles (see the sketch after this list):
    • Inputs are cast to FP32 when loaded from global memory.
    • All math/calculations in the kernels are done in FP32.
    • `scores` is always in FP32 (matching the MCore implementation).
    • `intermediate_output` is always in FP32 for better backward precision.
    • Casts to low precision happen only when necessary; for example, the gradient of the input is required to have the same dtype as the input.
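A minimal PyTorch sketch of these principles, assuming illustrative tensor names (not the actual kernel variables):

```python
import torch

def sqrtsoftplus_scores(logits: torch.Tensor) -> torch.Tensor:
    """Illustrative only: sqrtsoftplus(x) = sqrt(softplus(x)), computed in FP32."""
    x = logits.float()  # cast to FP32 on load, whatever the input dtype
    # torch.nn.functional.softplus is numerically stable: it returns x directly
    # for x > threshold (default 20.0) to avoid overflow in exp(x).
    scores = torch.nn.functional.softplus(x).sqrt()
    return scores  # scores stay in FP32, matching the MCore implementation
```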

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring


Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 self-assigned this Feb 6, 2026
@yaox12 yaox12 marked this pull request as ready for review February 6, 2026 06:41
@yaox12 yaox12 added the MoE label Feb 6, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds a new sqrtsoftplus score function to the fused MoE router, computed as sqrt(softplus(x)) = sqrt(log(1 + exp(x))). The implementation also switches all router math from FP64 to FP32 for better SM103 performance while maintaining precision.

Key Changes:

  • Added apply_sqrtsoftplus_on_float and apply_sqrtsoftplus_bwd_on_float functions in utils.h with numerically stable implementation (threshold=20.0 matching PyTorch)
  • Integrated sqrtsoftplus into forward/backward kernels for both fused_topk_with_score_function and fused_score_for_moe_aux_loss
  • For the sqrtsoftplus backward pass, the implementation saves the original logits in intermediate_output and recomputes the activation output during backward instead of storing it (see the sketch after this list)
  • Updated Python API, C++ bindings, and tests to support the new score function
  • Added comprehensive test coverage with various configurations (different topk, group_topk, expert_bias combinations)
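A hedged PyTorch reference for that backward path — recompute the activation from the saved logits, then apply the chain rule. The eps guard is an assumption standing in for the kernel's actual stability measures:

```python
import torch

def sqrtsoftplus_bwd(saved_logits: torch.Tensor, grad_scores: torch.Tensor,
                     eps: float = 1e-20) -> torch.Tensor:
    x = saved_logits.float()              # FP32 math throughout
    sp = torch.nn.functional.softplus(x)  # recompute the activation output
    # d/dx sqrt(softplus(x)) = softplus'(x) / (2 * sqrt(softplus(x)))
    #                        = sigmoid(x) / (2 * sqrt(softplus(x)))
    grad_x = grad_scores * torch.sigmoid(x) / (2.0 * sp.sqrt() + eps)
    return grad_x.to(saved_logits.dtype)  # gradient matches the input dtype
```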

Precision Strategy:
All computations now use FP32 throughout the kernels (inputs cast to FP32 on load, only cast back to input dtype when writing gradients). This improves performance on SM103 while maintaining numerical stability.

Previous Issues Addressed:
The developer has addressed all previously raised comments about error messages, gradient computation, and documentation accuracy.

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations for numerical edge cases
  • The implementation is well-tested and mathematically sound. The sqrtsoftplus forward pass uses numerically stable softplus computation, and the backward gradient formula is correct. The switch to FP32 math improves performance while the code properly handles precision through careful casting. All previous review comments have been addressed. The score is 4 rather than 5 due to the complexity of the numerical computations and the potential for edge cases in extreme input ranges, though the implementation does include stability measures like epsilon additions and thresholding.
  • No files require special attention. The CUDA kernels in fused_topk_with_score_function.cu and fused_score_for_moe_aux_loss.cu are the most complex but appear correct with proper gradient computation and activation recomputation logic.

Important Files Changed

| Filename | Overview |
|----------|----------|
| transformer_engine/common/fused_router/utils.h | Added sqrtsoftplus activation function and its backward pass. Implementation looks correct with numerically stable softplus computation and proper gradient calculation. |
| transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Integrated sqrtsoftplus into forward/backward kernels. Correctly saves original logits to intermediate_output for backward pass and recomputes activation output during backward. |
| transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Added sqrtsoftplus support to aux loss score computation. Forward and backward passes correctly handle activation recomputation from saved logits. |
| transformer_engine/pytorch/csrc/extensions/router.cpp | Updated C++ bindings to support sqrtsoftplus in score function map and validation. Correctly allows expert bias with both sigmoid and sqrtsoftplus. |
| transformer_engine/pytorch/router.py | Updated Python API and documentation to support sqrtsoftplus. Added comprehensive precision notes explaining FP32 math strategy for SM103 performance. |
| tests/pytorch/test_fused_router.py | Added comprehensive tests for sqrtsoftplus with various configurations. Reference implementation using PyTorch's softplus().sqrt() is correct. |
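For orientation, a hypothetical call into the updated Python API. `fused_topk_with_score_function` is named in this PR, but the import path and argument names below are assumptions, not the verified signature:

```python
import torch
from transformer_engine.pytorch.router import fused_topk_with_score_function  # path assumed

logits = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
expert_bias = torch.zeros(64, dtype=torch.float32, device="cuda")

# Argument names are illustrative; see router.py for the real interface.
probs, routing_map = fused_topk_with_score_function(
    logits,
    topk=4,
    score_function="sqrtsoftplus",  # new option added by this PR
    expert_bias=expert_bias,        # now allowed for sigmoid and sqrtsoftplus
)
```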

Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyTorchAPI as PyTorch API
    participant CPPExt as C Extension
    participant CUDAKernel as CUDA Kernel
    participant Memory

    User->>PyTorchAPI: Call with sqrtsoftplus
    PyTorchAPI->>CPPExt: Forward function call
    CPPExt->>CPPExt: Validate score function
    CPPExt->>Memory: Allocate FP32 buffers
    CPPExt->>CUDAKernel: Launch forward kernel

    Note over CUDAKernel: Forward Pass FP32 math
    CUDAKernel->>CUDAKernel: Load and cast logits to FP32
    CUDAKernel->>Memory: Save original logits
    CUDAKernel->>CUDAKernel: Apply sqrtsoftplus
    CUDAKernel->>CUDAKernel: Add expert bias
    CUDAKernel->>CUDAKernel: TopK selection
    CUDAKernel->>CUDAKernel: Normalize if needed
    CUDAKernel->>Memory: Write results

    CUDAKernel-->>CPPExt: Complete
    CPPExt-->>PyTorchAPI: Return tensors
    PyTorchAPI-->>User: Return probs and routing map

    User->>PyTorchAPI: Call backward
    PyTorchAPI->>CPPExt: Backward function call
    CPPExt->>CUDAKernel: Launch backward kernel

    Note over CUDAKernel: Backward Pass FP32 math
    CUDAKernel->>Memory: Load gradients and saved logits
    CUDAKernel->>CUDAKernel: Recompute sqrtsoftplus
    CUDAKernel->>CUDAKernel: Normalization backward
    CUDAKernel->>CUDAKernel: Sqrtsoftplus gradient
    CUDAKernel->>Memory: Write grad logits

    CUDAKernel-->>CPPExt: Complete
    CPPExt-->>PyTorchAPI: Return gradients
    PyTorchAPI-->>User: Propagate gradients
```

Contributor

@greptile-apps greptile-apps bot left a comment


7 files reviewed, 2 comments


@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_router.h
Expert bias API mismatch

The header still documents expert_bias as "Only used at the sigmoid case" (line 27), but this PR enables expert_bias for sqrtsoftplus end-to-end (kernels handle score_function==2, and router.cpp allows it). Please update the API docs here (and any other public docs) to reflect that expert_bias is supported for sqrtsoftplus too, otherwise external callers will get contradictory guidance.

Contributor

@greptile-apps greptile-apps bot left a comment


4 files reviewed, 1 comment


Contributor

@greptile-apps greptile-apps bot left a comment


8 files reviewed, 1 comment


@ptrendx ptrendx requested a review from denera February 10, 2026 22:37
Contributor

@greptile-apps greptile-apps bot left a comment


10 files reviewed, 1 comment


@greptile-apps
Contributor

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

tests/pytorch/test_fused_router.py
Missing expert_bias grad check

run_comparison() enables expert_bias.requires_grad = True for score_function in ("sigmoid", "sqrtsoftplus") (tests/pytorch/test_fused_router.py:177-180), but the test only asserts logits.grad equality (line ~230) and never checks expert_bias.grad vs expert_bias_clone.grad. This will let regressions in the fused backward w.r.t. expert_bias merge undetected for the bias-enabled paths.
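A minimal sketch of the missing assertion, reusing the names quoted from run_comparison() (the exact variable names in the test may differ):

```python
import torch

# After the existing logits.grad comparison:
if score_function in ("sigmoid", "sqrtsoftplus"):
    torch.testing.assert_close(expert_bias.grad, expert_bias_clone.grad)
```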

Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
Signed-off-by: Xin Yao <[email protected]>
@yaox12 yaox12 force-pushed the xiny/add_score_func branch from 70136ee to 3bace73 Compare February 11, 2026 09:53
Contributor

@greptile-apps greptile-apps bot left a comment


10 files reviewed, 3 comments


@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

transformer_engine/pytorch/router.py
Backward returns wrong arity

FusedTopkScoreFunction.forward takes 9 inputs after ctx (including expert_bias), but backward currently returns only 8 entries (grad_logits, None, None, None, None, None, None, None). Autograd expects one gradient slot per forward input, so this will raise at runtime. Add a trailing None for the expert_bias gradient (and keep the ordering aligned with the forward signature).
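A sketch of the fix this comment implies, assuming the 9-input forward signature it describes; only the trailing None for expert_bias is new:

```python
import torch

class FusedTopkScoreFunction(torch.autograd.Function):
    # Forward elided; it takes 9 inputs after ctx, with expert_bias last.

    @staticmethod
    def backward(ctx, *grad_outputs):
        grad_logits = ...  # actual gradient computation elided
        # One slot per forward input: the gradient for logits, None for the
        # seven non-differentiable arguments, and a trailing None for expert_bias.
        return grad_logits, None, None, None, None, None, None, None, None
```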

```cpp
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
```
Member Author

@yaox12 yaox12 Feb 11, 2026


Creating the output tensor on the Python side because the intermediate_output is now always in FP32, while the grad_logits should have the same dtype as the input.
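A sketch of that allocation pattern (variable names hypothetical):

```python
# grad_logits must match the input dtype, while intermediate_output stays FP32,
# so the buffer is created on the Python side rather than inside the extension.
grad_logits = torch.empty_like(logits)  # same dtype/shape as the input logits
```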

Signed-off-by: Xin Yao <[email protected]>
Contributor

@greptile-apps greptile-apps bot left a comment


10 files reviewed, no comments

