[Pytorch] Add get_backward_dw_params api for TE module #2614

timmoon10 merged 4 commits into NVIDIA:main from
Conversation
Greptile Overview

Greptile Summary

This PR refactors weight gradient hook management by extracting the hook trigger logic into a centralized `_trigger_wgrad_accumulation_and_reduce_hooks()` method.

Key Changes:
Why This Matters:

Previous Review Concerns Addressed:

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User as User/Megatron-LM
participant Graph as CUDA Graph (graph.py)
participant Module as TE Module
participant Hooks as Wgrad Hooks
User->>Graph: Call backward_dw() on graphed callable
Graph->>Graph: Replay bwd_dw_graph
Graph->>Module: Check if isinstance(TransformerEngineBaseModule)
Graph->>Module: Check need_backward_dw()
alt Module needs backward_dw
Graph->>Module: _trigger_wgrad_accumulation_and_reduce_hooks()
Module->>Hooks: Execute each registered hook
Hooks-->>User: Grad reduce completed
end
Note over Graph,Module: Direct backward_dw() path (no graph)
User->>Module: Call backward_dw() directly
Module->>Module: Process wgrad computation
Module->>Module: _trigger_wgrad_accumulation_and_reduce_hooks()
Module->>Hooks: Execute each registered hook
Hooks-->>User: Grad reduce completed
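
In Python terms, the graph-replay branch of this diagram corresponds roughly to the sketch below. It is a minimal sketch only: `bwd_dw_graph` and `captured_modules` are placeholder names, and the control flow is inferred from the diagram, not copied from graph.py.

```python
# Rough sketch of the graph-side backward_dw() path, assuming the method names
# shown in this PR; `bwd_dw_graph` and `captured_modules` are placeholder names.
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

def backward_dw(bwd_dw_graph, captured_modules):
    # Replay the captured wgrad CUDA graph first.
    bwd_dw_graph.replay()
    # Then fire the registered hooks (e.g. Megatron-LM grad reduce) for every
    # TE module that still has pending delayed-wgrad work.
    for module in captured_modules:
        if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
            module._trigger_wgrad_accumulation_and_reduce_hooks()
```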
logic: in `backward_dw()` (lines 1520-1522), weight tensors are only accessed when `not self.fuse_wgrad_accumulation`, but this method unconditionally returns weight parameters. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that shouldn't have them when `fuse_wgrad_accumulation=True`.
commit content reverted.
/te-ci pytorch L1
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
0906e63 to 5dfe8c1
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
/te-ci pytorch L1
@timmoon10 Can you help review this PR? Thanks~
    def _trigger_wgrad_accumulation_and_reduce_hooks(self):
        """
        Trigger the wgrad accumulation and reduce hooks.
        """
        for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
            wgrad_accumulation_and_reduce_hook()
Unconditional hook trigger
TransformerEngineBaseModule.backward_dw() now always calls _trigger_wgrad_accumulation_and_reduce_hooks() (transformer_engine/pytorch/module/base.py:1529-1536). In the fuse_wgrad_accumulation=True path the module does not materialize weight_tensor.grad (base.py:1520-1522), so these hooks may run without the expected grads present. If any hook assumes .grad exists (common for grad-reduce hooks), this will raise at runtime specifically when delayed wgrad compute is enabled with fused wgrad accumulation.
These hooks are specific to Megatron-LM. They are delicate interfaces for expert users.
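
As a concrete illustration of the failure mode raised above, a Megatron-LM-style hook stored in `wgrad_accumulation_and_reduce_hooks` would need to tolerate a missing `.grad` when fused wgrad accumulation is enabled. The sketch below is hypothetical: the hook body, `process_group`, and the `main_grad` handling are assumptions about the user side, not TE or Megatron-LM code.

```python
import torch

def make_grad_reduce_hook(param, process_group):
    """Hypothetical hook body; TE only stores and calls it via
    _trigger_wgrad_accumulation_and_reduce_hooks()."""
    def hook():
        # With fuse_wgrad_accumulation=True the wgrad is accumulated into
        # param.main_grad and param.grad may never be materialized, so
        # guard before touching the gradient.
        grad = getattr(param, "main_grad", None)
        if grad is None:
            grad = param.grad
        if grad is not None:
            torch.distributed.all_reduce(grad, group=process_group)
    return hook

# Registration side (the module keeps the callbacks in a plain list):
# module.wgrad_accumulation_and_reduce_hooks.append(make_grad_reduce_hook(weight, dp_group))
```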
/te-ci pytorch L1
Description
This PR adds `get_backward_dw_params` for TE modules, which helps manage the hooks of parameters. For Megatron-LM, `get_backward_dw_params` will be called once the wgrad CUDA graph has been executed. Currently the backward post-hook of the wgrad computation is discarded, which causes parameters to skip grad reduce.
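
As a rough usage sketch of the intended Megatron-LM-side call pattern (the loop, the process group, and the reduce call below are assumptions for illustration, not code from this PR):

```python
import torch

def reduce_wgrads_after_graph_replay(te_modules, data_parallel_group):
    # After the wgrad CUDA graph has been replayed, the usual backward
    # post-hooks never fire, so grad reduce is driven explicitly from the
    # parameters exposed by the new API.
    for module in te_modules:
        for param in module.get_backward_dw_params():
            if param.grad is not None:
                torch.distributed.all_reduce(param.grad, group=data_parallel_group)
```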
Type of change

Changes
Please list the changes introduced in this PR:
Checklist: