
Conversation

@cyanguwa (Collaborator) commented Jan 12, 2026

Description

This PR enables determinism for FP16/BF16 attention on Blackwell. It requires cuDNN >= 9.18.1.

To enable deterministic execution, set export NVTE_ALLOW_NONDETERMINISTIC_ALGO=0.

Support matrix for FP16/BF16 on Blackwell:

  • cuDNN 9.7.0-9.18.0: non-deterministic execution only; dbias is supported, but not together with dropout
  • cuDNN 9.18.1+: non-deterministic execution with dbias supported (but not together with dropout); deterministic execution supported when there is no dbias and no dropout
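The opt-in above can be sketched in Python. The environment variable is the documented switch; the parsing line below is only an illustrative sketch of how such a flag is typically read, not TE's actual implementation.

```python
import os

# Opt in to deterministic attention before Transformer Engine is imported.
# Setting NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 asks TE to reject backends whose
# backward pass is non-deterministic (on Blackwell this needs cuDNN >= 9.18.1).
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

# Illustrative sketch of reading the flag (not TE's actual parsing code):
deterministic = os.environ.get("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0"
print(deterministic)  # True: deterministic kernels are requested
```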

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please see Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa cyanguwa changed the title [Common] Enable determinism for SDPA on Blackwell [Common] Enable determinism for cuDNN >= 9.18 on Blackwell Jan 12, 2026
greptile-apps (Contributor) bot commented Jan 12, 2026

Greptile Summary

This PR enables deterministic FusedAttention on Blackwell GPUs (sm100+) for FP16/BF16 precisions with cuDNN >= 9.18.1.

Key changes:

  • Adds deterministic parameter to nvte_get_fused_attn_backend() API that controls backend selection based on determinism requirements
  • Updates backend selection logic in fused_attn.cpp:447-452 to enable arbitrary_seqlen backend on Blackwell for:
    • Non-deterministic mode (cuDNN >= 9.7): requires either no bias OR no dropout
    • Deterministic mode (cuDNN >= 9.18.1): requires no bias AND no dropout
  • Forward passes always use deterministic=false since forward implementation is always deterministic; backward passes use the actual user setting
  • Removes blanket Blackwell training disable from PyTorch utils since determinism is now supported
  • Adds deterministic test runs to both PyTorch and JAX test suites with separate XML output files
  • Updates cudnn-frontend submodule to support new determinism features

Implementation correctly handles:

  • The asymmetry between forward (always deterministic) and backward (optional deterministic) passes is intentional per cuDNN implementation
  • Logic uses OR for non-deterministic (bias and dropout are each allowed individually, but not together) and AND for deterministic (neither is allowed), matching cuDNN constraints
  • All previous review threads have been addressed or resolved
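The OR/AND gating described above can be condensed into a small predicate. This is a hypothetical Python helper written only to illustrate the constraints; the real logic lives in C++ in fused_attn.cpp and involves many more parameters.

```python
# Hypothetical sketch of the Blackwell backend gating (not the actual
# fused_attn.cpp code). Inputs: cuDNN version as a tuple, whether a bias
# (dbias) is used, the dropout probability, and the requested determinism.

def blackwell_fused_attn_supported(cudnn_version, has_bias, dropout, deterministic):
    if deterministic:
        # Deterministic mode: cuDNN >= 9.18.1 AND neither bias nor dropout.
        return cudnn_version >= (9, 18, 1) and not has_bias and dropout == 0.0
    # Non-deterministic mode: cuDNN >= 9.7.0, and bias and dropout may not
    # be combined (no bias OR no dropout).
    return cudnn_version >= (9, 7, 0) and (not has_bias or dropout == 0.0)

print(blackwell_fused_attn_supported((9, 18, 1), False, 0.0, True))   # True
print(blackwell_fused_attn_supported((9, 18, 1), True, 0.1, False))   # False
```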

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • All changes are well-structured additions that enable deterministic execution on Blackwell GPUs. The implementation correctly handles the complex logic for backend selection based on cuDNN version, determinism requirements, and parameter constraints. Previous review comments have been addressed. Test coverage has been extended with dedicated deterministic test runs.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/common/fused_attn/fused_attn.cpp Adds deterministic parameter to backend selection function and properly threads it through forward/backward passes. Forward passes use deterministic=false while backward passes use the actual parameter value, which is correct per developer's explanation that forward is always deterministic.
transformer_engine/jax/cpp_extensions/attention.py Updates Blackwell determinism assertions to support cuDNN >= 9.18.1 for deterministic backward pass. Logic correctly implements constraints: non-deterministic requires (bias XOR dropout), deterministic requires (no bias AND no dropout).
tests/jax/test_fused_attn.py Updates test skip conditions to match new Blackwell determinism support. Properly separates non-deterministic and deterministic test paths with correct version and parameter checks.
tests/pytorch/attention/test_attention.py Adds deterministic flag to backend availability checks and includes new deterministic test run. Properly sets NVTE_UNFUSED_ATTN environment variable for consistency.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Passes deterministic flag to backend selection and removes blanket Blackwell training disable since determinism is now supported with cuDNN 9.18+.

Sequence Diagram

sequenceDiagram
    participant User
    participant PyTorch/JAX
    participant Backend Selection
    participant cuDNN

    User->>PyTorch/JAX: Set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
    PyTorch/JAX->>Backend Selection: get_fused_attn_backend(deterministic=true)
    
    alt Blackwell (sm100+) && training
        Backend Selection->>Backend Selection: Check cuDNN version
        alt cuDNN >= 9.18.1 && no_bias && no_dropout
            Backend Selection->>Backend Selection: Enable arbitrary_seqlen backend
            Backend Selection-->>PyTorch/JAX: NVTE_F16_arbitrary_seqlen
        else cuDNN < 9.18.1 || bias || dropout
            Backend Selection->>Backend Selection: Disable arbitrary_seqlen backend
            Backend Selection-->>PyTorch/JAX: NVTE_F16_max512_seqlen or No_Backend
        end
    else Non-Blackwell
        Backend Selection->>Backend Selection: Enable arbitrary_seqlen backend
        Backend Selection-->>PyTorch/JAX: NVTE_F16_arbitrary_seqlen
    end
    
    PyTorch/JAX->>cuDNN: Forward pass (always deterministic)
    cuDNN-->>PyTorch/JAX: Forward result + aux tensors
    
    PyTorch/JAX->>cuDNN: Backward pass (deterministic if requested)
    cuDNN-->>PyTorch/JAX: Backward gradients

greptile-apps bot left a review: 1 file reviewed, 1 comment

greptile-apps (Contributor) bot commented Jan 12, 2026

Greptile Overview

Greptile Summary

Overview

This PR enables determinism for FusedAttention on Blackwell GPUs (SM 100) with cuDNN version 9.18.0 or higher. The implementation moves determinism checking logic from Python to the C++ backend selection layer.

Key Changes

  1. Backend Selection Logic: Added a new condition in nvte_get_fused_attn_backend() that disables the arbitrary sequence length backend for Blackwell when:

    • Training mode is enabled
    • Determinism is required
    • Any of: cuDNN < 9.18.0, bias is used, or dropout > 0
  2. API Updates: Added deterministic parameter to the backend selection function across Python, C++, and JAX interfaces. Forward passes hardcode deterministic=true while backward passes accept it as a parameter.

  3. Code Migration: Moved Blackwell determinism checks from Python (utils.py) to C++ backend selection, consolidating version, bias, and dropout checks in one place.

  4. Test Infrastructure: Added environment variable NVTE_ALLOW_NONDETERMINISTIC_ALGO to control determinism in tests, and added explicit NVTE_UNFUSED_ATTN=0 settings to ensure proper backend isolation.

  5. Dependency Update: Updated cudnn-frontend submodule to version 1.17 to support the new determinism features.

Architecture

The change follows a layered approach:

  • User API Level: Python tests set deterministic flag via environment variable or torch settings
  • Python Layer: Extracts deterministic flag and passes to C++ extension
  • C++ Backend Selection: Evaluates hardware, cuDNN version, bias, and dropout to determine if deterministic FusedAttention is supported
  • Execution: If requirements aren't met, falls back to other backends (FlashAttention or UnfusedDotProductAttention)

The implementation correctly restricts deterministic FusedAttention to cases where cuDNN guarantees deterministic behavior, avoiding silent non-determinism.
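The fallback behavior described above can be sketched as a simple preference order. The function name and ordering here are illustrative assumptions for this summary, not TE's actual selection code.

```python
# Sketch of the fallback described above: if deterministic FusedAttention
# requirements aren't met, attention falls back to another backend.

def select_attention_backend(fused_attn_ok, flash_attn_ok):
    """Return the first available backend in a hypothetical preference order."""
    if fused_attn_ok:
        return "FusedAttention"
    if flash_attn_ok:
        return "FlashAttention"
    return "UnfusedDotProductAttention"

# e.g. a deterministic Blackwell run with dropout > 0 rejects fused attention,
# so selection falls through to the remaining backends.
print(select_attention_backend(fused_attn_ok=False, flash_attn_ok=False))
```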

Confidence Score: 4/5

  • This PR is safe to merge with minor issues that should be addressed
  • The implementation is sound and correctly adds determinism support for Blackwell GPUs. The core logic properly checks cuDNN version, bias, and dropout constraints. However, two issues lower the confidence: (1) inconsistent tab/space indentation in the critical condition on line 444 of fused_attn.cpp, and (2) duplicate XML output file in test.sh causing test results to be overwritten. Both are non-critical but should be fixed before merge.
  • Pay attention to transformer_engine/common/fused_attn/fused_attn.cpp (line 444 indentation) and qa/L0_pytorch_unittest/test.sh (line 48 XML filename collision)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 4/5 Added determinism check for Blackwell (sm100) to disable FusedAttention when cuDNN < 9.18.0 or bias/dropout are used. Contains tab indentation inconsistency on line 444.
transformer_engine/pytorch/attention/dot_product_attention/utils.py 5/5 Removed Python-side Blackwell determinism check, now handled in C++. Added deterministic parameter to backend selection call.
tests/pytorch/attention/test_attention.py 5/5 Added deterministic flag from environment variable and torch settings. Updated tests to explicitly set NVTE_UNFUSED_ATTN=0 to ensure correct backend isolation.
qa/L0_pytorch_unittest/test.sh 3/5 Added deterministic test run with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0. Both test runs write to same XML file causing results to be overwritten.

Sequence Diagram

sequenceDiagram
    participant User as User/Test
    participant PyAPI as Python API
    participant Utils as utils.py
    participant CppExt as C++ Extensions
    participant Backend as Backend Selection
    participant cuDNN as cuDNN Library

    User->>PyAPI: Call attention with deterministic=True
    PyAPI->>Utils: get_attention_backend(params)
    Utils->>Utils: Extract deterministic from params
    Utils->>CppExt: get_fused_attn_backend(..., deterministic)
    CppExt->>Backend: nvte_get_fused_attn_backend(..., deterministic)
    
    alt Blackwell (sm_arch >= 100) & Training & Deterministic
        Backend->>Backend: Check cuDNN version >= 9.18.0
        Backend->>Backend: Check bias_type == NO_BIAS
        Backend->>Backend: Check dropout == 0.0
        alt All checks pass
            Backend-->>CppExt: F16_arbitrary_seqlen backend
        else Any check fails
            Backend-->>CppExt: No_Backend (disabled)
        end
    else Other architectures or inference
        Backend->>Backend: Apply standard backend selection
        Backend-->>CppExt: Selected backend
    end
    
    CppExt-->>Utils: Backend choice
    Utils-->>PyAPI: Backend configuration
    
    alt Forward Pass
        PyAPI->>CppExt: nvte_fused_attn_fwd(..., deterministic=true)
        Note over PyAPI,CppExt: Forward always uses deterministic=true
    else Backward Pass
        PyAPI->>CppExt: nvte_fused_attn_bwd(..., deterministic)
        Note over PyAPI,CppExt: Backward respects user's deterministic flag
    end
    
    CppExt->>cuDNN: Execute attention operation
    cuDNN-->>CppExt: Results
    CppExt-->>PyAPI: Output tensors
    PyAPI-->>User: Attention output

greptile-apps bot left a review: 2 files reviewed, 2 comments

make .xml file specific to deterministic tests in qa/

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
greptile-apps bot left a review: 1 file reviewed, 1 comment

greptile-apps bot left a review: no files reviewed, no comments

greptile-apps bot left a review: 1 file reviewed, 1 comment

greptile-apps bot left a review: 1 file reviewed, 1 comment

cyanguwa and others added 3 commits January 13, 2026 06:00
fix typo

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
fix indentation

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
greptile-apps bot left a review: 1 file reviewed, 1 comment

greptile-apps bot left a review: 1 file reviewed, 1 comment

greptile-apps bot left a review: 3 files reviewed, 3 comments

greptile-apps bot left a review: 2 files reviewed, 2 comments

greptile-apps bot left a review: 3 files reviewed, 3 comments

greptile-apps (Contributor) bot commented Jan 14, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@cyanguwa (Collaborator, Author): /te-ci L0

@cyanguwa (Collaborator, Author): /te-ci jax L0

@cyanguwa (Collaborator, Author): /te-ci L0

@cyanguwa (Collaborator, Author): /te-ci L1

@cyanguwa (Collaborator, Author): /te-ci L1

greptile-apps bot left a review: 13 files reviewed, 3 comments

cyanguwa and others added 3 commits January 15, 2026 06:57
Signed-off-by: Charlene Yang <[email protected]>
fix and/or logic

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <[email protected]>
@cyanguwa (Collaborator, Author): /te-ci L1

@liayan commented Jan 16, 2026

Cool, we are currently suffering from this issue.
Do we have a rough timeline for when it could be merged?
Let me know if there is anything I can do, such as a test. Would like to help.

@KshitijLakhani (Collaborator) previously approved these changes Jan 16, 2026 and left a comment:

Left a few comments - some suggested changes and some questions.
Looks good to me, otherwise. Approving to not block from merge, if urgent.

It would be helpful to have a table in the PR description of what's supported for < cuDNN 9.18, >= cuDNN 9.18, < sm100, sm100+, dropout, dbias, etc.

I would also suggest looking into the number of tests being run and their timing (you can compare your PR's L0 jax and L0 pyt timings to the timings in TE 2.11 or in TE main CI); we would not want to go overboard with our timing budget, for sure. If you can report the timing in the PR, that would be helpful as well.
Worst case, if urgent, we can merge this PR and address the QA bit (which runs in the CI) in a separate PR subsequently.

Lastly, this might take some effort but would ensure correctness: since the code for skipping tests in TE JAX has been modified, it would be good to compare the test count before and after this PR to check whether tests that should not be skipped are incorrectly being skipped.

mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_deterministic.xml $TE_PATH/tests/jax/test_fused_attn.py || test_fail "tests/jax/test_fused_attn.py"
@KshitijLakhani (Collaborator) commented:

It seems like this will first run the non-deterministic fused attn tests as part of L31, which runs all non-distributed tests, followed by the deterministic fused attn tests as part of L32.
Is that the intention: to run fused attn twice, with and without determinism?

That will greatly increase our test time and might be unnecessary. The last pipeline launched was for L1, so I am unsure I can track the effect this change will have on timing, as this is an L0 change. Could you report that in the PR, please?
Thanks!

@KshitijLakhani (Collaborator) commented:

Maybe we could come up with an approach that runs half the fused attn tests deterministically and the other half non-deterministically?
Or run all of them deterministically only?

@cyanguwa (Collaborator, Author) commented:

Yes, this extra line tests test_fused_attn.py with determinism, while the line before it tests everything with non-determinism. The extra test_fused_attn.py run takes ~20 mins on Blackwell:

================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test_backward                                                | 5040x | 1336.28s | avg:   0.27s
================================================================================
TOTAL RUNTIME                                                |      | 1336.28s |
================================================================================

float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
@KshitijLakhani (Collaborator) commented:

nit: For consistency, should we call this flag is_deterministic, similar to the first arg, is_training?

@cyanguwa (Collaborator, Author) commented:

I felt there was a bit of a distinction when I was implementing it: is_training describes the state we are in, while deterministic is more of a request from the user (that they want to run in deterministic mode). Not a lot of difference, to be honest, just a feel of the words. I kind of did this when I introduced deterministic as a parameter in AttentionParams, so I just followed along with it in this PR. Any strong objections?

@KshitijLakhani (Collaborator): /te-ci L0 L1

@cyanguwa cyanguwa changed the title [Common] Enable determinism for cuDNN >= 9.18 on Blackwell [Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell Jan 19, 2026
@cyanguwa (Collaborator, Author): /te-ci L1

@cyanguwa (Collaborator, Author) commented:

Pipeline 42017245 for CI with updated cuDNN.
