
[Pytorch] Add get_backward_dw_params api for TE module #2614

Merged
timmoon10 merged 4 commits into NVIDIA:main from Wohox:pingtian/add_linear_wgrad_compute_param_api on Feb 9, 2026

Conversation

@Wohox
Contributor

Wohox commented Jan 22, 2026

Description

This PR adds get_backward_dw_params for TE modules, which helps manage the parameters' hooks.

For Megatron-LM, get_backward_dw_params will be called once the wgrad CUDA graph is executed. Currently, the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip grad reduce.
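A minimal sketch of the failure mode described above, assuming a TE module object `module` and a captured wgrad graph `bwd_dw_graph` (the graph name is taken from the sequence diagram further down); this is illustrative, not code from this PR:

```python
# Direct path: backward_dw() runs the delayed wgrad computation and then
# fires the registered grad-accumulation/reduce hooks.
module.backward_dw()   # hooks fire, gradients get reduced

# Graphed path before this PR: replaying the captured wgrad graph only
# re-runs the recorded CUDA kernels; the Python-level backward_post_hook
# is never invoked, so the affected parameters silently skip grad reduce.
bwd_dw_graph.replay()  # no hooks fire, grad reduce is missed
```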

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Extract the wgrad hook trigger logic into a shared _trigger_wgrad_accumulation_and_reduce_hooks() method on TransformerEngineBaseModule, used by backward_dw() in base.py, GroupedLinear, and LayerNormMLP
  • Trigger these hooks after the wgrad CUDA graph is replayed in graph.py, so parameters no longer skip grad reduce

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Jan 22, 2026

Greptile Overview

Greptile Summary

This PR refactors weight gradient hook management by extracting hook trigger logic into a centralized _trigger_wgrad_accumulation_and_reduce_hooks() method and ensuring hooks are called after CUDA graph replay.

Key Changes:

  • Extracted hook triggering into _trigger_wgrad_accumulation_and_reduce_hooks() in TransformerEngineBaseModule (base.py:1531-1536)
  • Updated backward_dw() in base.py, GroupedLinear, and LayerNormMLP to use the new method
  • Added hook triggering after wgrad graph replay in graph.py (lines 864-870) to fix missing grad reduce when using CUDA graphs with Megatron-LM

Why This Matters:
The PR fixes a bug where backward_post_hook for wgrad computation was being discarded during CUDA graph replay, causing parameters to skip grad reduce in Megatron-LM. The refactoring centralizes hook management and ensures hooks run consistently whether using direct backward_dw() calls or CUDA graph replay.
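Based on this summary and the sequence diagram below, the post-replay logic in graph.py (around lines 864-870) is roughly the following sketch; `graphed_modules` and the exact loop structure are assumptions for illustration, while the isinstance filter, need_backward_dw() check, and hook trigger are taken from this review:

```python
# Sketch of the post-replay hook pass, not the verbatim diff: after the wgrad
# graph is replayed, walk the graphed TE modules and re-fire the Megatron-LM
# grad accumulation/reduce hooks that the graph replay alone would skip.
bwd_dw_graph.replay()

for module in graphed_modules:  # graphed_modules: illustrative name
    if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
        module._trigger_wgrad_accumulation_and_reduce_hooks()
```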

Previous Review Concerns Addressed:
All previously raised concerns have been addressed by the developer through code revisions and clarifications.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The changes are a clean refactoring that extracts duplicate code into a reusable method and fixes a specific integration bug with Megatron-LM. The isinstance check at graph.py:867 safely filters modules before calling methods, preventing errors. All previous review concerns have been addressed.
  • No files require special attention

Important Files Changed

Filename | Overview
transformer_engine/pytorch/module/base.py | Extracts hook trigger logic into _trigger_wgrad_accumulation_and_reduce_hooks() for better code organization and reusability
transformer_engine/pytorch/graph.py | Adds hook triggering after wgrad graph replay to ensure grad reduce hooks run when using CUDA graphs
transformer_engine/pytorch/module/grouped_linear.py | Replaces the inline hook loop with a call to the centralized _trigger_wgrad_accumulation_and_reduce_hooks() method
transformer_engine/pytorch/module/layernorm_mlp.py | Replaces the inline hook loop with a call to the centralized _trigger_wgrad_accumulation_and_reduce_hooks() method

Sequence Diagram

sequenceDiagram
    participant User as User/Megatron-LM
    participant Graph as CUDA Graph (graph.py)
    participant Module as TE Module
    participant Hooks as Wgrad Hooks

    User->>Graph: Call backward_dw() on graphed callable
    Graph->>Graph: Replay bwd_dw_graph
    Graph->>Module: Check if isinstance(TransformerEngineBaseModule)
    Graph->>Module: Check need_backward_dw()
    alt Module needs backward_dw
        Graph->>Module: _trigger_wgrad_accumulation_and_reduce_hooks()
        Module->>Hooks: Execute each registered hook
        Hooks-->>User: Grad reduce completed
    end
    
    Note over Graph,Module: Direct backward_dw() path (no graph)
    User->>Module: Call backward_dw() directly
    Module->>Module: Process wgrad computation
    Module->>Module: _trigger_wgrad_accumulation_and_reduce_hooks()
    Module->>Hooks: Execute each registered hook
    Hooks-->>User: Grad reduce completed


@Wohox
Contributor Author

Wohox commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

Contributor

logic: in backward_dw() (lines 1520-1522), weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that shouldn't have them when fuse_wgrad_accumulation=True.

Contributor Author

commit content reverted.

Contributor

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

@Wohox
Contributor Author

Wohox commented Jan 30, 2026

@ksivaman Can you help review this PR? It's a bug fix for #2376.

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

@lhb8125
Contributor

lhb8125 commented Feb 5, 2026

/te-ci pytorch L1

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Wohox force-pushed the pingtian/add_linear_wgrad_compute_param_api branch from 0906e63 to 5dfe8c1 on February 5, 2026, 03:30
Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
@Wohox
Contributor Author

Wohox commented Feb 5, 2026

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

Contributor

@lhb8125 lhb8125 left a comment

LGTM.

@Wohox
Contributor Author

Wohox commented Feb 9, 2026

@timmoon10 Can you help review this PR? Thanks!

Contributor

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 2 comments

Comment on lines +1531 to +1536
def _trigger_wgrad_accumulation_and_reduce_hooks(self):
    """
    Trigger the wgrad accumulation and reduce hooks.
    """
    for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
        wgrad_accumulation_and_reduce_hook()
Contributor

Unconditional hook trigger
TransformerEngineBaseModule.backward_dw() now always calls _trigger_wgrad_accumulation_and_reduce_hooks() (transformer_engine/pytorch/module/base.py:1529-1536). In the fuse_wgrad_accumulation=True path the module does not materialize weight_tensor.grad (base.py:1520-1522), so these hooks may run without the expected grads present. If any hook assumes .grad exists (common for grad-reduce hooks), this will raise at runtime specifically when delayed wgrad compute is enabled with fused wgrad accumulation.

Collaborator

These hooks are specific to Megatron-LM. They are delicate interfaces for expert users.
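To make the concern and this reply concrete, a hook of the kind stored in wgrad_accumulation_and_reduce_hooks might look like the hypothetical sketch below; the main_grad buffer and the process group argument are assumptions about how such Megatron-LM hooks are typically written, not code from either repository:

```python
import torch.distributed as dist

def make_wgrad_reduce_hook(param, process_group):
    """Hypothetical grad-reduce hook registered by an expert-user framework."""
    def hook():
        # With fuse_wgrad_accumulation=True the wgrad GEMM accumulates directly
        # into a main_grad buffer, so the hook reduces that buffer. A hook that
        # read param.grad instead would find it unset, which is the failure mode
        # raised in the comment above.
        dist.all_reduce(param.main_grad, group=process_group)
    return hook
```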

Collaborator

@timmoon10 timmoon10 left a comment

LGTM, pending CI



Contributor

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments

@timmoon10
Collaborator

/te-ci pytorch L1

timmoon10 merged commit 2894e49 into NVIDIA:main on Feb 9, 2026
27 of 32 checks passed