[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell #2584
base: main
Conversation
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
Greptile Summary

This PR enables deterministic FusedAttention on Blackwell GPUs (sm100+) for FP16/BF16 precisions with cuDNN >= 9.18.1. Key changes:
Implementation correctly handles:
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant PyTorch/JAX
    participant Backend Selection
    participant cuDNN
    User->>PyTorch/JAX: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
    PyTorch/JAX->>Backend Selection: get_fused_attn_backend(deterministic=true)
    alt Blackwell (sm100+) && training
        Backend Selection->>Backend Selection: Check cuDNN version
        alt cuDNN >= 9.18.1 && no_bias && no_dropout
            Backend Selection->>Backend Selection: Enable arbitrary_seqlen backend
            Backend Selection-->>PyTorch/JAX: NVTE_F16_arbitrary_seqlen
        else cuDNN < 9.18.1 || bias || dropout
            Backend Selection->>Backend Selection: Disable arbitrary_seqlen backend
            Backend Selection-->>PyTorch/JAX: NVTE_F16_max512_seqlen or No_Backend
        end
    else Non-Blackwell
        Backend Selection->>Backend Selection: Enable arbitrary_seqlen backend
        Backend Selection-->>PyTorch/JAX: NVTE_F16_arbitrary_seqlen
    end
    PyTorch/JAX->>cuDNN: Forward pass (always deterministic)
    cuDNN-->>PyTorch/JAX: Forward result + aux tensors
    PyTorch/JAX->>cuDNN: Backward pass (deterministic if requested)
    cuDNN-->>PyTorch/JAX: Backward gradients
```
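For reference, here is a minimal Python sketch of the gating behavior shown in the diagram above. It is illustrative only: the real check lives in the C++ backend selection, and the function name, argument names, and return labels below are assumptions, not the actual API.

```python
# Illustrative sketch of the backend gating shown above; not the real TE code.
# Function name, arguments, and return labels are assumptions for clarity.
def select_f16_backend(sm_arch, cudnn_version, is_training, deterministic,
                       has_bias, dropout_p):
    blackwell = sm_arch >= 100
    if deterministic and is_training and blackwell:
        # On Blackwell, deterministic training needs cuDNN >= 9.18.1,
        # no bias, and no dropout; otherwise arbitrary_seqlen is disabled.
        if cudnn_version >= (9, 18, 1) and not has_bias and dropout_p == 0.0:
            return "NVTE_F16_arbitrary_seqlen"
        return "NVTE_F16_max512_seqlen or No_Backend"
    # Non-Blackwell (or inference): arbitrary_seqlen stays enabled.
    return "NVTE_F16_arbitrary_seqlen"


# e.g. Blackwell + cuDNN 9.18.1, training, no bias, no dropout -> deterministic backend
assert select_f16_backend(100, (9, 18, 1), True, True, False, 0.0) == "NVTE_F16_arbitrary_seqlen"
```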
1 file reviewed, 1 comment
Greptile Overview

This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism-checking logic from Python to the C++ backend-selection layer.

Key Changes

Architecture

The change follows a layered approach:

The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.

Confidence Score: 4/5

Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Test
    participant PyAPI as Python API
    participant Utils as utils.py
    participant CppExt as C++ Extensions
    participant Backend as Backend Selection
    participant cuDNN as cuDNN Library
    User->>PyAPI: Call attention with deterministic=True
    PyAPI->>Utils: get_attention_backend(params)
    Utils->>Utils: Extract deterministic from params
    Utils->>CppExt: get_fused_attn_backend(..., deterministic)
    CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
    alt Blackwell (sm_arch >= 100) & Training & Deterministic
        Backend->>Backend: Check cuDNN version >= 9.18.0
        Backend->>Backend: Check bias_type == NO_BIAS
        Backend->>Backend: Check dropout == 0.0
        alt All checks pass
            Backend-->>CppExt: F16_arbitrary_seqlen backend
        else Any check fails
            Backend-->>CppExt: No_Backend (disabled)
        end
    else Other architectures or inference
        Backend->>Backend: Apply standard backend selection
        Backend-->>CppExt: Selected backend
    end
    CppExt-->>Utils: Backend choice
    Utils-->>PyAPI: Backend configuration
    alt Forward Pass
        PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
        Note over PyAPI,CppExt: Forward always uses deterministic=true
    else Backward Pass
        PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
        Note over PyAPI,CppExt: Backward respects user's deterministic flag
    end
    CppExt->>cuDNN: Execute attention operation
    cuDNN-->>CppExt: Results
    CppExt-->>PyAPI: Output tensors
    PyAPI-->>User: Attention output
```
2 files reviewed, 2 comments
make .xml file specific to deterministic tests in qa/
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
No files reviewed, no comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
fix typo
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
fix indentation
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
1 file reviewed, 1 comment
1 file reviewed, 1 comment
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
3 files reviewed, 3 comments
2 files reviewed, 2 comments
3 files reviewed, 3 comments
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

/te-ci L0
Signed-off-by: Charlene Yang <[email protected]>
/te-ci jax L0
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L0

/te-ci L1
Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1
for more information, see https://pre-commit.ci
13 files reviewed, 3 comments
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
fix and/or logic
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
/te-ci L1

Cool, we are currently suffering from this issue.
KshitijLakhani
left a comment
Left a few comments - some suggested changes and some questions.
Looks good to me otherwise. Approving so as not to block the merge, if this is urgent.
It would be helpful to have a table in the PR description showing what is supported for cuDNN < 9.18, cuDNN >= 9.18, < sm100, sm100+, dropout, dbias, etc.
I would also suggest looking into the number of tests being run and their timing (you can compare this PR's L0 JAX and L0 PyTorch timings to the timings in TE 2.11 or in TE main CI); we would not want to go overboard with our timing budget. Reporting the timing in the PR would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA bit (which runs in the CI) in a separate PR.
Lastly, this might take some effort but would ensure correctness: since the code for skipping tests in the TE JAX tests has been modified, it would be good to compare the test count before and after this PR to verify that tests that should not be skipped are not incorrectly being skipped.
```bash
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
```
It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention: to run fused attn twice, with and without determinism?
That will greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am not sure I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!
Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run them all deterministically only?
Yes, this extra line tests test_fused_attn.py with determinism, while the line before tests everything with non-determinism. The extra test_fused_attn.py test takes ~20mins on Blackwell:
```
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_backward | 5040x | 1336.28s | avg: 0.27s
================================================================================
TOTAL RUNTIME |       | 1336.28s |
================================================================================
```
```diff
     float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
```
nit: To be consistent, should we call this flag is_deterministic, similar to the first arg, is_training?
I felt there was a bit of a distinction when I was implementing it: is_training is a description of the state we are in, while deterministic is more of a request from the user (that they want to run in deterministic mode). Not a lot of difference, to be honest; just a feel for the words. I did this when I introduced deterministic as a parameter in AttentionParams, so I just followed along with it in this PR. Any strong objections?
/te-ci L0 L1
Signed-off-by: Charlene Yang <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L1

Pipeline 42017245 for CI with updated cuDNN.
Description
This PR enables determinism for FP16/BF16 attention on Blackwell. It requires cuDNN >= 9.18.1.
To enable determinism, please set `export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`.

Support matrix for FP16/BF16 on Blackwell:
Type of change
Changes
Please see Description.
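As a usage illustration (not part of this diff), here is a minimal PyTorch sketch of running attention deterministically, assuming a Blackwell GPU with cuDNN >= 9.18.1. The DotProductAttention settings and tensor shapes are placeholders; only the NVTE_ALLOW_NONDETERMINISTIC_ALGO variable comes from this feature.

```python
import os
# Request deterministic algorithms before transformer_engine selects a backend.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

import torch
import transformer_engine.pytorch as te

# Placeholder config; determinism on Blackwell requires no bias and no dropout.
dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64,
                             attention_dropout=0.0)

def run_once():
    torch.manual_seed(0)
    # sbhd layout: [seq, batch, heads, head_dim]
    q, k, v = (torch.randn(128, 2, 16, 64, dtype=torch.bfloat16, device="cuda",
                           requires_grad=True) for _ in range(3))
    out = dpa(q, k, v)           # forward
    out.sum().backward()         # backward
    return out.detach().clone(), q.grad.detach().clone()

out1, dq1 = run_once()
out2, dq2 = run_once()
# With NVTE_ALLOW_NONDETERMINISTIC_ALGO=0, repeated runs should match bit-for-bit.
assert torch.equal(out1, out2) and torch.equal(dq1, dq2)
```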
Checklist: