
[C] NVFP4 quantization for GroupedTensor #2655

Merged
ksivaman merged 4 commits into NVIDIA:main from ksivaman:nvfp4_grouped_quantize
Feb 11, 2026

Conversation

@ksivaman ksivaman (Member) commented Feb 6, 2026

Description

Pieces taken from #2600.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • NVFP4 quantization for grouped tensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Zhongbo Zhu <[email protected]>
Co-authored-by: Zhongbo Zhu <[email protected]>
@ksivaman ksivaman added the MoE label Feb 6, 2026
@ksivaman ksivaman marked this pull request as draft February 6, 2026 06:38
@greptile-apps greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds NVFP4 quantization support for GroupedTensor with graph-safe APIs that can handle device-managed tensor grouping. The changes introduce two new CUDA implementations and corresponding C API functions for grouped Hadamard transforms with quantization.

Major changes:

  • Added nvte_group_hadamard_transform_amax_graph_safe() for computing amax with Hadamard transforms on grouped tensors
  • Added nvte_group_hadamard_transform_cast_fusion_graph_safe() for NVFP4 quantization with row/column Hadamard transforms
  • Added nvte_group_amax_graph_safe() for grouped amax without Hadamard transforms
  • New implementations use TMA (Tensor Memory Accelerator) kernels and CUDA graphs for efficiency

Issues found:

  • Multiple TODO comments in graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu requesting verification of group index calculations and tensor layout logic (lines 709, 724, 778, 795)
  • Missing input sanity checks noted in TODO at line 1352
  • No tests found for the new graph-safe APIs
  • PR checklist indicates functionality is incomplete and tests are missing

Confidence Score: 2/5

  • This PR contains incomplete functionality with unverified critical logic that needs thorough review before merging
  • The PR author has explicitly marked functionality as incomplete in the checklist. Multiple TODO comments request verification of critical group index calculations and tensor layout logic. Missing input validation and no test coverage for the new APIs. The core quantization kernel contains complex tensor manipulations that haven't been verified.
  • Primary attention needed on graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu - verify group index calculations, tensor layout logic, and add input validation before merging

Important Files Changed

  • transformer_engine/common/include/transformer_engine/hadamard_transform.h: added two new C API declarations for graph-safe grouped Hadamard operations, with documentation
  • transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu: implements the graph-safe grouped Hadamard transform with amax computation using TMA kernels; currently limited to a constant last dimension
  • transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu: implements NVFP4 quantization with Hadamard transforms for grouped tensors; contains multiple TODO comments requesting verification of critical logic, missing input validation, and complex tensor layout calculations

Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant API as C API Layer
    participant Kernel as CUDA Kernel
    participant GPU as GPU Memory

    Client->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
    API->>API: Convert NVTEGroupedTensor to GroupedTensor
    API->>API: Convert NVTEQuantizationConfig to QuantizationConfig
    API->>API: Validate quant_workspace (>=4 bytes)

    alt Stochastic Rounding Enabled
        API->>API: Validate RNG state tensor
    end

    API->>Kernel: Launch group_hadamard_transform_cast_fusion_graph_safe()
    Kernel->>GPU: TMA load input tensors
    Kernel->>Kernel: Determine tensor ID from offset (binary search)
    Kernel->>Kernel: Apply row-wise quantization
    Kernel->>Kernel: Apply Hadamard transform
    Kernel->>Kernel: Apply column-wise quantization to NVFP4
    Kernel->>Kernel: Compute scaling factors (FP8 E4M3)
    Kernel->>GPU: Write quantized output + scaling factors
    Kernel-->>API: Return
    API-->>Client: Return
```

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments

Comment on lines +709 to +724
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here

Multiple TODO comments request logic verification in critical group index calculation and tensor layout code; verify that the group_idx calculation and tensor layout logic are correct before merging.

Suggested change
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
make_shape(M, cur_N), DStride{}); // (M,packed_N)

Comment on lines +778 to +795
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
c_global_amax_val = shared_storage.global_d_amax[group_idx];
// update amax
global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
// TODO(zhongbo): double check the logic here

More TODO comments in the epilogue loop; verify the group index recalculation and amax scaling logic.


using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here

Add the input sanity checks noted in the TODO.

Signed-off-by: Zhongbo Zhu <[email protected]>
@ksivaman ksivaman (Member, Author) commented Feb 9, 2026

Fixes #2510

@ksivaman ksivaman marked this pull request as ready for review February 11, 2026 06:23
@greptile-apps greptile-apps bot left a comment

6 files reviewed, 2 comments

Comment on lines +1500 to +1512
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

if (input_tensor->num_tensors == 0) {
return;
}

// Call the multi-tensor Hadamard transform amax implementation.
group_hadamard_transform_cast_fusion_graph_safe(
input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
*quant_workspace_tensor, stream);

Uninitialized quant_config used

QuantizationConfig quant_config_cpp; is left uninitialized when quant_config == nullptr, but it’s still passed into group_hadamard_transform_cast_fusion_graph_safe(...) (line 1511). That’s undefined behavior and can lead to garbage config values being read inside the kernel launch path. Consider default-initializing (QuantizationConfig quant_config_cpp{};) or returning an error when quant_config is required.

@greptile-apps greptile-apps bot (Contributor) commented Feb 11, 2026

Additional Comments (1)

transformer_engine/common/CMakeLists.txt
Missing CUTLASS debug flags

hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu includes CUTLASS/CUTE headers, but it’s not listed in CUTLASS_KERNEL_SOURCES, so it won’t get the -g0;-dopt=on compile options that this file group relies on to avoid debug-build hangs. Add this new source to CUTLASS_KERNEL_SOURCES so it is compiled with the same options as the other CUTLASS kernels.

@ksivaman ksivaman (Member, Author) commented

/te-ci

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@zhongbozhu zhongbozhu (Collaborator) left a comment

Conditional approve; more changes to come in #2600.

@ksivaman ksivaman merged commit 402ea54 into NVIDIA:main Feb 11, 2026
10 of 12 checks passed