[C] NVFP4 quantization for GroupedTensor #2655
Conversation
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Zhongbo Zhu <[email protected]>
Co-authored-by: Zhongbo Zhu <[email protected]>
Greptile Overview
Greptile Summary
This PR adds NVFP4 quantization support for GroupedTensor. Major changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant API as C API Layer
    participant Kernel as CUDA Kernel
    participant GPU as GPU Memory
    Client->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
    API->>API: Convert NVTEGroupedTensor to GroupedTensor
    API->>API: Convert NVTEQuantizationConfig to QuantizationConfig
    API->>API: Validate quant_workspace (>=4 bytes)
    alt Stochastic Rounding Enabled
        API->>API: Validate RNG state tensor
    end
    API->>Kernel: Launch group_hadamard_transform_cast_fusion_graph_safe()
    Kernel->>GPU: TMA load input tensors
    Kernel->>Kernel: Determine tensor ID from offset (binary search)
    Kernel->>Kernel: Apply row-wise quantization
    Kernel->>Kernel: Apply Hadamard transform
    Kernel->>Kernel: Apply column-wise quantization to NVFP4
    Kernel->>Kernel: Compute scaling factors (FP8 E4M3)
    Kernel->>GPU: Write quantized output + scaling factors
    Kernel-->>API: Return
    API-->>Client: Return
```
```cpp
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
```
Multiple TODO comments request logic verification in critical group-index calculation and tensor-layout code. Verify that the group_idx calculation and the tensor layout logic are correct before merging.
Suggested change:

```cpp
// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                            reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                        make_shape(M, cur_N), DStride{});  // (M,packed_N)
```
```cpp
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
  group_idx = cur_group_idx;
  c_global_amax_val = shared_storage.global_d_amax[group_idx];
  // update amax
  global_encode_scale = c_global_amax_val > 0.0f
                            ? cutlass::minimum_with_nan_propagation<float>{}(
                                  (fp8_max * fp4_max) / c_global_amax_val,
                                  cutlass::platform::numeric_limits<float>::max())
                            : 1.0f;
  global_decode_scale = 1.0f / global_encode_scale;
  if constexpr (kUseFastMath) {
    global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
  }
// TODO(zhongbo): double check the logic here
```
More TODO comments in the epilogue loop. Verify the group-index recalculation and the amax scaling logic.
```cpp
using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
```
Add the input sanity checks noted in the TODO.
Signed-off-by: Zhongbo Zhu <[email protected]>
Fixes #2510
```cpp
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
  quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

if (input_tensor->num_tensors == 0) {
  return;
}

// Call the multi-tensor Hadamard transform amax implementation.
group_hadamard_transform_cast_fusion_graph_safe(
    input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
    *quant_workspace_tensor, stream);
```
Uninitialized quant_config used
QuantizationConfig quant_config_cpp; is left uninitialized when quant_config == nullptr, but it’s still passed into group_hadamard_transform_cast_fusion_graph_safe(...) (line 1511). That’s undefined behavior and can lead to garbage config values being read inside the kernel launch path. Consider default-initializing (QuantizationConfig quant_config_cpp{};) or returning an error when quant_config is required.
Additional Comments (1)
/te-ci |
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
zhongbozhu
left a comment
There was a problem hiding this comment.
Conditional approve. More changes to come for #2600
Description
Pieces taken from #2600.
Type of change
Changes
Checklist: