[Common][PyTorch] Add a new score func sqrtsoftplus to the fused router (#2633)
base: main
Conversation
Greptile Summary

This PR adds a new `sqrtsoftplus` score function to the fused router.

Key Changes:
Precision Strategy:
Previous Issues Addressed:

Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyTorchAPI as PyTorch API
    participant CPPExt as C++ Extension
    participant CUDAKernel as CUDA Kernel
    participant Memory
    User->>PyTorchAPI: Call with sqrtsoftplus
    PyTorchAPI->>CPPExt: Forward function call
    CPPExt->>CPPExt: Validate score function
    CPPExt->>Memory: Allocate FP32 buffers
    CPPExt->>CUDAKernel: Launch forward kernel
    Note over CUDAKernel: Forward Pass (FP32 math)
    CUDAKernel->>CUDAKernel: Load and cast logits to FP32
    CUDAKernel->>Memory: Save original logits
    CUDAKernel->>CUDAKernel: Apply sqrtsoftplus
    CUDAKernel->>CUDAKernel: Add expert bias
    CUDAKernel->>CUDAKernel: TopK selection
    CUDAKernel->>CUDAKernel: Normalize if needed
    CUDAKernel->>Memory: Write results
    CUDAKernel-->>CPPExt: Complete
    CPPExt-->>PyTorchAPI: Return tensors
    PyTorchAPI-->>User: Return probs and routing map
    User->>PyTorchAPI: Call backward
    PyTorchAPI->>CPPExt: Backward function call
    CPPExt->>CUDAKernel: Launch backward kernel
    Note over CUDAKernel: Backward Pass (FP32 math)
    CUDAKernel->>Memory: Load gradients and saved logits
    CUDAKernel->>CUDAKernel: Recompute sqrtsoftplus
    CUDAKernel->>CUDAKernel: Normalization backward
    CUDAKernel->>CUDAKernel: Sqrtsoftplus gradient
    CUDAKernel->>Memory: Write grad logits
    CUDAKernel-->>CPPExt: Complete
    CPPExt-->>PyTorchAPI: Return gradients
    PyTorchAPI-->>User: Propagate gradients
```
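The forward steps in the diagram can be sketched as a plain-Python reference. This is an illustrative sketch only: the exact definition of `sqrtsoftplus` (assumed here to be `sqrt(softplus(x))`, as the name suggests) and the bias convention (the expert bias affects top-k selection but not the returned probabilities) are assumptions, not the kernel's verified behavior.

```python
import math

def sqrtsoftplus(x):
    # Numerically stable softplus, then sqrt; the exact kernel math
    # is an assumption based on the score function's name.
    sp = math.log1p(math.exp(-abs(x))) + max(x, 0.0)
    return math.sqrt(sp)

def route(logits, expert_bias, topk, normalize=True):
    # Mirrors the diagram: cast logits to FP32 (Python floats here),
    # score, add expert bias, select top-k, then optionally normalize.
    scores = [sqrtsoftplus(float(l)) for l in logits]
    biased = [s + b for s, b in zip(scores, expert_bias)]
    # Top-k on the biased scores; probabilities come from unbiased scores
    # (a common aux-loss-free load-balancing convention; an assumption here).
    order = sorted(range(len(biased)), key=lambda i: biased[i], reverse=True)
    selected = sorted(order[:topk])
    probs = [scores[i] for i in selected]
    if normalize:
        total = sum(probs)
        probs = [p / total for p in probs]
    return selected, probs
```

With zero bias this reduces to picking the experts with the largest logits and renormalizing their scores, since `sqrtsoftplus` is monotonically increasing.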
7 files reviewed, 2 comments
Additional Comments (1)
The header still documents
4 files reviewed, 1 comment
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
8 files reviewed, 1 comment
10 files reviewed, 1 comment
Additional Comments (1)
Signed-off-by: Xin Yao <[email protected]>
Force-pushed from 70136ee to 3bace73
10 files reviewed, 3 comments
transformer_engine/common/fused_router/fused_topk_with_score_function.cu
Additional Comments (1)
```cpp
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
```
The output tensor is created on the Python side because `intermediate_output` is now always in FP32, while `grad_logits` should have the same dtype as the input.
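The comment above can be sketched with a toy tensor type: the backward wrapper pins the intermediate buffer to FP32 while allocating `grad_logits` in the input's dtype. All names here are hypothetical; this only illustrates the allocation split, not the real extension API.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class ToyTensor:
    data: List[float]
    dtype: str  # e.g. "bf16", "fp16", "fp32"

def backward_wrapper(grad_probs: ToyTensor, logits: ToyTensor) -> ToyTensor:
    # intermediate_output is always FP32, regardless of the input dtype.
    intermediate = ToyTensor([0.0] * len(logits.data), "fp32")
    # grad_logits is allocated "on the Python side" in the input dtype,
    # then would be filled by the (omitted here) CUDA backward kernel.
    grad_logits = ToyTensor([0.0] * len(logits.data), logits.dtype)
    assert intermediate.dtype == "fp32"
    return grad_logits
```

The point of the split is that low-precision inputs still get FP32 accumulation internally, while the gradient handed back to autograd matches the dtype autograd expects.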
Signed-off-by: Xin Yao <[email protected]>
10 files reviewed, no comments
Description

- Add a new score func `sqrtsoftplus`.
- `scores` is always in FP32 (match the MCore implementation).
- `intermediate_output` is always in FP32 for better backward precision.

Type of change
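The "better backward precision" point matters because the backward pass recomputes the score from the saved logits and applies an analytic gradient. Assuming `sqrtsoftplus(x) = sqrt(softplus(x))` (an assumption from the name, not taken from the kernel source), the chain rule gives d/dx = sigmoid(x) / (2·sqrt(softplus(x))), which a finite-difference check confirms:

```python
import math

def softplus(x):
    # Numerically stable softplus.
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)

def sqrtsoftplus_grad(x):
    # d/dx sqrt(softplus(x)) = sigmoid(x) / (2 * sqrt(softplus(x)))
    sigmoid = 1.0 / (1.0 + math.exp(-x))
    return sigmoid / (2.0 * math.sqrt(softplus(x)))

# Central finite-difference check at an arbitrary point.
x, h = 0.7, 1e-6
numeric = (math.sqrt(softplus(x + h)) - math.sqrt(softplus(x - h))) / (2 * h)
```

Because the denominator involves a square root of a recomputed quantity, evaluating it in FP32 rather than the (possibly BF16) input dtype avoids amplifying rounding error for small softplus values.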