Pytorch only reference for mxfp4 to avoid aiter compile on each GH actions job #114

Open
arseniivanov wants to merge 2 commits into gpu-mode:main from arseniivanov:pytorch_only_mxfp4_mm

Conversation

@arseniivanov

This is working; I tested against the aiter baseline by running it in its own kernel.

PyTorch eager:
⚡ 499 µs 🐌 856 µs
⚡ 1151 µs 🐌 1280 µs
⚡ 510 µs 🐌 1054 µs
⚡ 482 µs 🐌 1442 µs
⚡ 1124 µs 🐌 1370 µs
⚡ 621 µs 🐌 1008 µs

With @torch.compile added to custom_kernel (not sure whether this works in the container solution, but the runtime is much faster):
⚡ 26.6 µs 🐌 32.0 µs
⚡ 198 µs 🐌 219 µs
⚡ 61.8 µs 🐌 69.4 µs
⚡ 56.3 µs 🐌 62.6 µs
⚡ 199 µs 🐌 216 µs
⚡ 106 µs 🐌 116 µs

Aiter baseline:
⚡ 11.1 µs 🐌 17.2 µs
⚡ 24.3 µs 🐌 31.2 µs
⚡ 11.4 µs 🐌 16.8 µs
⚡ 11.4 µs 🐌 16.5 µs
⚡ 13.4 µs 🐌 18.8 µs
⚡ 12.1 µs 🐌 16.8 µs

@msaroufim
Member

@danielhua23

@danielhua23
Contributor

Do you have a time comparison of the PyTorch impl vs. the aiter kernel + aiter compile for the GitHub Actions job? If the latter is slow, then we can go ahead.

@danielhua23
Contributor

And are shuffle_scale, shuffle_Weight, quantize_fp4, etc. moved from the aiter UT files?

@danielhua23
Contributor

It might be better to do an accuracy test like #113. Thanks.

@arseniivanov
Author

> Do you have a time comparison of the PyTorch impl vs. the aiter kernel + aiter compile for the GitHub Actions job? If the latter is slow, then we can go ahead.

I don't know how to test this; I would have to dig into it, which might be beyond me given that I am participating in the competition myself. I just saw Siro mention on the server that the compile time of the aiter reference was a big reason for the slow feedback loops, and that you were planning to do a PyTorch-only ref.

> And are shuffle_scale, shuffle_Weight, quantize_fp4, etc. moved from the aiter UT files?

To my knowledge, the aiter UT files (https://github.com/ROCm/aiter/blob/main/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py) still reference the aiter helpers like shuffle_weights. I have referenced the sources I used in comments in the code; it is mostly Triton code from shuffle_weights and fp4_helpers converted to torch. I think there are some magic numbers involved for the AMD hardware around concepts like subnormals, so I tried to follow the Triton impl.
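For readers unfamiliar with the format, here is a minimal standalone sketch of the E2M1 decode step that the converted helpers implement (the names `unpack_fp4` and `MXFP4_LUT` are illustrative, not the PR's actual helpers): each uint8 byte packs two 4-bit codes, and each code indexes a 16-entry lookup table of the representable values.

```python
import torch

# E2M1 value table: 1 sign bit, 2 exponent bits, 1 mantissa bit -> 16 values.
MXFP4_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    """Unpack two 4-bit codes per uint8 byte and map them through the LUT."""
    lo = packed & 0xF                               # first element of each pair
    hi = packed >> 4                                # second element of each pair
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)
    return MXFP4_LUT[codes.long()]

# One byte packs two fp4 values: 0x2 -> 1.0 (low nibble), 0x5 -> 3.0 (high nibble).
print(unpack_fp4(torch.tensor([0x52], dtype=torch.uint8)))  # tensor([1., 3.])
```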

> It might be better to do an accuracy test like #113. Thanks.

I don't have access to the MI355x beyond the popcorn-cli, so I can only test against the aiter ref there. When I run the tests against the aiter ref in popcorn-cli, I get:

│ ✅ Passed 4/4 tests:
│✅ k: 7168; m: 8; n: 2112; seed: 124
│> Maximum error: 0.0
│✅ k: 1536; m: 16; n: 3072; seed: 6635
│> Maximum error: 0.0
│✅ k: 1536; m: 64; n: 3072; seed: 45
│> Maximum error: 0.0
│✅ k: 512; m: 256; n: 2880; seed: 78
│> Maximum error: 0.0

I think it's bitwise identical due to following the Triton impl.
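One reason bitwise-identical results are plausible: the E8M0 scale decode involves no rounding at all, only placing the biased exponent byte into the float32 exponent field. A minimal sketch of that identity (the name `e8m0_decode` is illustrative):

```python
import torch

def e8m0_decode(biased_exp: torch.Tensor) -> torch.Tensor:
    # Shift the 8-bit biased exponent into the float32 exponent field
    # (bits 23..30); reinterpreted as float32 this is exactly 2 ** (e - 127).
    return (biased_exp.to(torch.int32) << 23).view(torch.float32)

e = torch.tensor([127, 128, 126], dtype=torch.uint8)
print(e8m0_decode(e))  # 2**0, 2**1, 2**-1 -> 1.0, 2.0, 0.5
```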

@danielhua23
Contributor

For the accuracy test, I mean you can follow the idea from #113: submit your torch code and see the accuracy comparison with the aiter ref.

@arseniivanov
Author

> For the accuracy test, I mean you can follow the idea from #113: submit your torch code and see the accuracy comparison with the aiter ref.

Yes, I have done this; it's at the bottom of the previous comment.

@danielhua23
Contributor

> For the accuracy test, I mean you can follow the idea from #113: submit your torch code and see the accuracy comparison with the aiter ref.

> Yes, I have done this; it's at the bottom of the previous comment.

Wow, you mean the torch impl and the aiter impl give exactly the same result?

@arseniivanov
Author

arseniivanov commented Mar 8, 2026

> For the accuracy test, I mean you can follow the idea from #113: submit your torch code and see the accuracy comparison with the aiter ref.

> Yes, I have done this; it's at the bottom of the previous comment.

> Wow, you mean the torch impl and the aiter impl give exactly the same result?

I do believe so. You can try it yourself; here is the kernel in a format you can submit:

#!POPCORN leaderboard amd-mxfp4-mm
#!POPCORN gpu MI355X
import torch
from task import input_t, output_t

def e8m0_to_f32(scale_e8m0_biased: torch.Tensor) -> torch.Tensor:
    # Decode biased E8M0 exponents into float32 scales (2 ** (e - 127)) by
    # shifting the byte into the float32 exponent field. Special encodings:
    # e == 0 maps to 2 ** -127 (a float32 subnormal), e == 0xFF maps to NaN.
    scale_e8m0_biased = scale_e8m0_biased.contiguous().view(torch.uint8)
    zero_case = scale_e8m0_biased == 0
    nan_case = scale_e8m0_biased == 0xFF
    
    scale_f32 = scale_e8m0_biased.to(torch.int32) << 23
    scale_f32[zero_case] = 0x00400000
    scale_f32[nan_case] = 0x7F800001
    return scale_f32.view(torch.float32)

def mxfp4_to_f32(x: torch.Tensor) -> torch.Tensor:
    # Unpack two E2M1 nibbles per byte, then map each 4-bit code through
    # the 16-entry MXFP4 value table.
    x = x.contiguous().view(torch.uint8)
    x = x.repeat_interleave(2, dim=-1)
    x[..., ::2] = x[..., ::2] & 0xF
    x[..., 1::2] = x[..., 1::2] >> 4
    
    mxfp4_list = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ]
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]

def quantize_mxfp4(x: torch.Tensor):
    M, N = x.shape
    x_blocks = x.view(M, N // 32, 32).float()

    # Calculate scale natively (matching Triton bitwise rounding)
    amax = x_blocks.abs().max(dim=-1, keepdim=True).values
    amax_i32 = amax.view(torch.int32)
    amax_i32 = (amax_i32 + 2097152) & ~8388607
    amax_rounded = amax_i32.view(torch.float32)

    log_amax = torch.log2(amax_rounded)
    scale_e8m0_unbiased = torch.floor(log_amax) - 2
    scale_e8m0_unbiased = torch.nan_to_num(scale_e8m0_unbiased, neginf=-127.0)
    scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -127, 127)

    bs_e8m0 = (scale_e8m0_unbiased + 127).to(torch.uint8)
    quant_scale = torch.exp2(-scale_e8m0_unbiased)
    
    qx = x_blocks * quant_scale
    qx_i32 = qx.view(torch.int32)
    
    e = (qx_i32 >> 23) & 0xFF
    m = qx_i32 & 0x7FFFFF
    s_bit = (qx < 0).to(torch.uint8) << 3 

    E8_BIAS, E2_BIAS = 127, 1
    adjusted_exponents = E8_BIAS - (e + 1)
    shift_amount = torch.clamp(adjusted_exponents, min=0, max=31)
    
    denorm_m = (4194304 | (m >> 1)) >> shift_amount
    m = torch.where(e < E8_BIAS, denorm_m, m)

    e = torch.maximum(e, torch.tensor(E8_BIAS - E2_BIAS)) - (E8_BIAS - E2_BIAS)

    e2m1_tmp = torch.minimum((((e << 2) | (m >> 21)) + 1) >> 1, torch.tensor(0x7))
    e2m1_value = s_bit | e2m1_tmp.to(torch.uint8)

    e2m1_value = e2m1_value.view(M, N)
    evens = e2m1_value[:, 0::2]
    odds = e2m1_value[:, 1::2]
    x_fp4 = evens | (odds << 4)

    return x_fp4, bs_e8m0.view(M, N // 32)

@torch.compile
def custom_kernel(data: input_t) -> output_t:
    A, B, *_ = data

    A_q, A_scale = quantize_mxfp4(A)
    B_q, B_scale = quantize_mxfp4(B)

    M, _ = A.shape
    N, _ = B.shape

    x_f32 = mxfp4_to_f32(A_q)
    x_scales = A_scale.repeat_interleave(32, dim=1)
    x_f32 = x_f32 * e8m0_to_f32(x_scales)

    w_f32 = mxfp4_to_f32(B_q)
    w_scales = B_scale.repeat_interleave(32, dim=1)
    w_f32 = w_f32 * e8m0_to_f32(w_scales)

    return torch.mm(x_f32, w_f32.T).to(A.dtype)[:M, :N]
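As a standalone sanity check on the E2M1 rounding path in quantize_mxfp4 above, here is a scalar pure-Python sketch of the same bit manipulations (it assumes a pre-scaled input with |x| <= 6; `f32_to_e2m1` is an illustrative name, not part of the submission):

```python
import struct

def f32_to_e2m1(x: float) -> int:
    # Reinterpret the float32 bits, mirroring qx.view(torch.int32) above.
    bits = struct.unpack('<i', struct.pack('<f', x))[0]
    e = (bits >> 23) & 0xFF          # biased float32 exponent
    m = bits & 0x7FFFFF              # 23-bit mantissa
    s = (1 if x < 0 else 0) << 3     # sign goes into bit 3 of the nibble

    if e < 127:
        # Subnormal in E2M1: shift an explicit leading-1 mantissa right.
        shift = min(max(127 - (e + 1), 0), 31)
        m = (0x400000 | (m >> 1)) >> shift
        e = 0
    else:
        e = e - 126                  # rebias: E8 bias 127 -> E2 bias 1

    # Round to nearest via the +1 >> 1 trick, clamp to the max code 0b111.
    return s | min((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)

# Every exactly representable magnitude should round-trip through the code.
lut = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
assert all(lut[f32_to_e2m1(v)] == v for v in lut)
assert f32_to_e2m1(-1.5) == 0b1011   # sign bit set on top of code 3
```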
