
Conversation

@cuichenx cuichenx (Contributor) commented Jan 31, 2026

Description

Problem

Using Float8BlockQuantizer with sequence parallel fails with AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer when local tensor dimensions aren't divisible by 128.
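
For intuition, a hypothetical illustration of the constraint behind the assertion (the actual check lived in the now-removed assert_dim_for_all_gather helper and is not reproduced here; the shape below is made up):

```python
# Hypothetical example: Float8 block scaling quantizes in 128-element blocks,
# so a local shard whose dimensions are not multiples of 128 cannot be
# quantized directly, and the old code raised the AssertionError above.
seq_per_rank, hidden = 100, 512     # made-up local activation shape under SP
quantizable = (seq_per_rank % 128 == 0) and (hidden % 128 == 0)
print(quantizable)                  # False -> the old assert fired
```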

Solution

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer, since gather_along_first_dim already has a fallback path
  • Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before the high-precision all-gather

Note
The fallback path (high-precision all-gather → quantize) may increase communication overhead, since the gather moves unquantized data.
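
For reference, a minimal sketch of the fallback flow under discussion, assembled from the snippets quoted later in this thread (names such as quantizer and process_group, and the duck-typed dequantize check, are simplifications, not the verbatim implementation):

```python
import warnings
import torch

def allgather_high_precision_fallback(inp, out_shape, dtype, device, quantizer, process_group):
    """Sketch: all-gather in high precision, then quantize the gathered result."""
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    # If the caller already quantized the local shard, dequantize it first
    # (the real code checks isinstance(inp, QuantizedTensorStorage)).
    if hasattr(inp, "dequantize"):
        inp = inp.dequantize(dtype=dtype)
    # High-precision all-gather moves unquantized data, hence the extra traffic
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp.contiguous(), group=process_group)
    # Re-quantize the gathered tensor with the original quantizer
    return quantizer(out)
```

The .contiguous() call reflects the reviewer suggestion further down; without it, a non-contiguous input view could fail inside the collective.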

Verification

The code change does not alter convergence behavior
image

When SP is enabled, the previous code did not run at all (it hit the assertion above). When SP is disabled, this change does not affect anything.
image

Type of change

  • Bug fix (non-breaking change which fixes an issue)
Changes

  • Skip the assert_dim_for_all_gather check for Float8BlockQuantizer; gather_along_first_dim already has a fallback path
  • Fix the fallback in _start_all_gather_fp8_blockwise to dequantize already-quantized inputs before the high-precision all-gather

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Chen Cui <chcui@nvidia.com>
@greptile-apps greptile-apps bot (Contributor) commented Jan 31, 2026

Greptile Overview

Greptile Summary

This PR removes a pre-check that required sequence-parallel all-gather inputs to be quantizable, and instead relies on distributed.gather_along_first_dim to fall back to a high-precision all-gather when quantization constraints (e.g., FP8 block scaling divisibility) aren’t met. It also adjusts the FP8-blockwise/NVFP4/MXFP8 gather implementations so that if the fallback path receives an already-quantized tensor storage, it is dequantized before performing the high-precision collective and then requantized.

The main functional changes are concentrated in transformer_engine/pytorch/distributed.py, while linear.py, layernorm_linear.py, and layernorm_mlp.py drop the now-removed assert_dim_for_all_gather call. utils.py removes the helper entirely.

Confidence Score: 3/5

  • This PR is close to mergeable but has a likely runtime issue in the new high-precision all-gather fallbacks that should be fixed first.
  • Core logic change (removing the quantizability assertion and improving fallback handling for already-quantized inputs) is reasonable, but the new fallback branches call all_gather_into_tensor without ensuring contiguous input, unlike other call sites in the same module. That can lead to runtime failures when the fallback triggers on non-contiguous views.
  • transformer_engine/pytorch/distributed.py

Important Files Changed

  • transformer_engine/pytorch/distributed.py: Adds high-precision fallbacks for FP8-blockwise/NVFP4/MXFP8 all-gather and dequantizes already-quantized inputs before the fallback; but the new fallback paths pass inp to all_gather_into_tensor without making it contiguous, unlike other call sites in this module.
  • transformer_engine/pytorch/module/layernorm_linear.py: Removes the assert_dim_for_all_gather check from FP8 paths; relies on gather_along_first_dim fallback behavior instead. No new correctness issues spotted in this file.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Removes the assert_dim_for_all_gather check from the FP8 + sequence-parallel path; behavior now depends on distributed.gather_along_first_dim fallbacks. No additional issues found in this file.
  • transformer_engine/pytorch/module/linear.py: Removes the assert_dim_for_all_gather check from the FP8 input all-gather path; otherwise unchanged. No new issues found here.
  • transformer_engine/pytorch/utils.py: Deletes assert_dim_for_all_gather and its Quantizer import; no remaining references. Change is consistent with moving reliance to gather_along_first_dim fallback behavior.

Sequence Diagram

sequenceDiagram
    participant M as Module (Linear/LNLinear/LNMLP)
    participant Q as *Quantizer
    participant D as distributed.gather_along_first_dim
    participant AG as torch.distributed.all_gather_into_tensor

    M->>Q: (optional) quantize local activation
    M->>D: gather_along_first_dim(inp, tp_group, quantizer)

    alt Float8BlockQuantizer and quantizable
        D->>D: _start_all_gather_fp8_blockwise
        D->>AG: all_gather_into_tensor(FP8 blocks + scales)
        D-->>M: Float8BlockwiseQTensorStorage (+ async handle)
    else Not quantizable / unsupported dims
        D->>D: high-precision fallback
        D->>Q: (if QuantizedTensorStorage) dequantize(dtype)
        D->>AG: all_gather_into_tensor(out_fp, inp_fp)
        D->>Q: quantize gathered output
        D-->>M: Quantized output tensor
    end

@cyanguwa cyanguwa requested a review from timmoon10 February 2, 2026 18:48
@timmoon10 timmoon10 self-requested a review February 2, 2026 19:38
Perform all-gather in high-precision if the input tensor is too small to quantize.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps greptile-apps bot left a comment

2 files reviewed, no comments

timmoon10 previously approved these changes Feb 2, 2026

@timmoon10 timmoon10 (Collaborator) left a comment

I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.

@timmoon10 timmoon10 (Collaborator) commented
/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 3 comments

out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
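
A minimal sketch of the adjustment the later revision adopts (quoted further down): pass the activation dtype explicitly so the fallback does not silently upcast to float32:

```python
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize(dtype=dtype)  # match the original activation dtype, e.g. bfloat16
```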

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed

Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx (Member) commented Feb 10, 2026

/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

Comment on lines 1102 to 1108
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    out = quantizer(out)

Non-contiguous gather input

In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
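
A minimal sketch of the fix this comment suggests, mirroring the contiguous-input handling at the other call sites it cites (variable names follow the snippet above):

```python
# Make the high-precision input contiguous before handing it to the collective,
# matching the plain-tensor and FP8 paths elsewhere in distributed.py.
torch.distributed.all_gather_into_tensor(
    out, inp.contiguous(), group=process_group, async_op=False
)
```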
