Pytorch only reference for mxfp4 to avoid aiter compile on each GH actions job #114
arseniivanov wants to merge 2 commits into gpu-mode:main from
Conversation
…new github actions job
do you have the time comparison for the pytorch impl vs the aiter kernel + aiter compile in the github actions job? if the latter is slow, then we can go ahead
and are the shuffle_scale, shuffle_weight, quantize_fp4, etc. moved over from the aiter UT files?
might be better to do an acc test like #113. Thanks
I don't know how to test this; I would have to dig into it, which might be beyond me given that I'm participating in the competition myself. I just saw Siro mention on the server that the compile time of the aiter reference was a big reason for the slow feedback loops, and that you were planning to do a PT-only ref.
To my knowledge, the aiter UT files (https://github.com/ROCm/aiter/blob/main/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py) still reference the aiter helpers like shuffle_weights. I have referenced the sources I used in comments in the code; it's mostly Triton code from shuffle_weights and fp4_helpers converted to torch. There are some magic numbers involved for the AMD hardware around concepts like subnormals, so I tried to follow the Triton impl closely.
I don't have access to the MI355x beyond the popcorn-cli, so I can only test against the aiter ref there. When I run the tests against the aiter ref in popcorn-cli, I get: ✅ Passed 4/4 tests. I think it's bitwise identical due to following the Triton impl.
for the acc test, I mean you can follow the idea from #113: submit your torch code and see the accuracy comparison with the aiter ref
Yes, I have done this; it's at the bottom of the previous comment.
wow, you mean the torch impl and the aiter impl give exactly the same result?
I do believe so. You can try it yourself; here is the kernel in a format you can submit:

```python
#!POPCORN leaderboard amd-mxfp4-mm
#!POPCORN gpu MI355X
import torch

from task import input_t, output_t


def e8m0_to_f32(scale_e8m0_biased: torch.Tensor) -> torch.Tensor:
    # Biased E8M0: the byte goes directly into the float32 exponent field.
    scale_e8m0_biased = scale_e8m0_biased.contiguous().view(torch.uint8)
    zero_case = scale_e8m0_biased == 0
    nan_case = scale_e8m0_biased == 0xFF
    scale_f32 = scale_e8m0_biased.to(torch.int32) << 23
    scale_f32[zero_case] = 0x00400000
    scale_f32[nan_case] = 0x7F800001
    return scale_f32.view(torch.float32)


def mxfp4_to_f32(x: torch.Tensor) -> torch.Tensor:
    # Unpack two FP4 (E2M1) values per byte, low nibble first.
    x = x.contiguous().view(torch.uint8)
    x = x.repeat_interleave(2, dim=-1)
    x[..., ::2] = x[..., ::2] & 0xF
    x[..., 1::2] = x[..., 1::2] >> 4
    mxfp4_list = [
        0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
        -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
    ]
    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device=x.device)
    return mxfp4_in_f32[x.long()]


def quantize_mxfp4(x: torch.Tensor):
    M, N = x.shape
    x_blocks = x.view(M, N // 32, 32).float()
    # Calculate scale natively (matching Triton bitwise rounding)
    amax = x_blocks.abs().max(dim=-1, keepdim=True).values
    amax_i32 = amax.view(torch.int32)
    amax_i32 = (amax_i32 + 2097152) & ~8388607
    amax_rounded = amax_i32.view(torch.float32)
    log_amax = torch.log2(amax_rounded)
    scale_e8m0_unbiased = torch.floor(log_amax) - 2
    scale_e8m0_unbiased = torch.nan_to_num(scale_e8m0_unbiased, neginf=-127.0)
    scale_e8m0_unbiased = torch.clamp(scale_e8m0_unbiased, -127, 127)
    bs_e8m0 = (scale_e8m0_unbiased + 127).to(torch.uint8)
    quant_scale = torch.exp2(-scale_e8m0_unbiased)
    qx = x_blocks * quant_scale
    # Round each scaled value to E2M1 via float32 bit manipulation
    qx_i32 = qx.view(torch.int32)
    e = (qx_i32 >> 23) & 0xFF
    m = qx_i32 & 0x7FFFFF
    s_bit = (qx < 0).to(torch.uint8) << 3
    E8_BIAS, E2_BIAS = 127, 1
    # Denormal handling for values below the E2M1 normal range
    adjusted_exponents = E8_BIAS - (e + 1)
    shift_amount = torch.clamp(adjusted_exponents, min=0, max=31)
    denorm_m = (4194304 | (m >> 1)) >> shift_amount
    m = torch.where(e < E8_BIAS, denorm_m, m)
    e = torch.maximum(e, torch.tensor(E8_BIAS - E2_BIAS)) - (E8_BIAS - E2_BIAS)
    e2m1_tmp = torch.minimum((((e << 2) | (m >> 21)) + 1) >> 1, torch.tensor(0x7))
    e2m1_value = s_bit | e2m1_tmp.to(torch.uint8)
    e2m1_value = e2m1_value.view(M, N)
    # Pack two FP4 values per byte: even columns in the low nibble
    evens = e2m1_value[:, 0::2]
    odds = e2m1_value[:, 1::2]
    x_fp4 = evens | (odds << 4)
    return x_fp4, bs_e8m0.view(M, N // 32)


@torch.compile
def custom_kernel(data: input_t) -> output_t:
    A, B, *_ = data
    A_q, A_scale = quantize_mxfp4(A)
    B_q, B_scale = quantize_mxfp4(B)
    M, _ = A.shape
    N, _ = B.shape
    # Dequantize both operands to f32 and do a plain matmul
    x_f32 = mxfp4_to_f32(A_q)
    x_scales = A_scale.repeat_interleave(32, dim=1)
    x_f32 = x_f32 * e8m0_to_f32(x_scales)
    w_f32 = mxfp4_to_f32(B_q)
    w_scales = B_scale.repeat_interleave(32, dim=1)
    w_f32 = w_f32 * e8m0_to_f32(w_scales)
    return torch.mm(x_f32, w_f32.T).to(A.dtype)[:M, :N]
```
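To sanity-check the two bit tricks in the kernel (the E8M0 scale decode and the Triton-style amax rounding) without a GPU, here is a minimal pure-Python sketch. The helper names (`decode_e8m0`, `round_amax`, `decode_packed`) are my own illustrative inventions, not aiter APIs:

```python
import struct

# E2M1 (FP4) code points, same lookup table as mxfp4_to_f32 above:
# bit 3 is the sign, then 2 exponent bits and 1 mantissa bit.
MXFP4 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def f32_bits(x: float) -> int:
    """Reinterpret a float32 as its raw 32-bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a float32."""
    return struct.unpack("<f", struct.pack("<I", b & 0xFFFFFFFF))[0]

def decode_e8m0(byte: int) -> float:
    """Scalar version of e8m0_to_f32: the biased byte is the float32
    exponent field, with the same 0x00 and 0xFF special cases."""
    if byte == 0xFF:
        return float("nan")
    if byte == 0:
        return bits_f32(0x00400000)
    return bits_f32(byte << 23)

def round_amax(x: float) -> float:
    """The bitwise rounding from quantize_mxfp4: adding 0x200000 and
    clearing the 23 mantissa bits snaps |x| to a power of two, carrying
    into the exponent (rounding up) when the mantissa is >= 0.75."""
    return bits_f32((f32_bits(x) + 0x200000) & ~0x7FFFFF)

def decode_packed(packed: int, scale_byte: int):
    """Unpack one packed FP4 byte (low nibble first) and apply the scale."""
    s = decode_e8m0(scale_byte)
    return MXFP4[packed & 0xF] * s, MXFP4[packed >> 4] * s
```

For example, `decode_e8m0(127)` gives 1.0 (bias 127, exponent 0) and `round_amax(1.9)` carries into the exponent to give 2.0, while 1.5 is snapped down to 1.0, matching the threshold the tensor code relies on before taking `log2`.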
Timings from testing against the aiter baseline by running it in its own kernel submission:
PyTorch eager:
⚡ 499 µs 🐌 856 µs
⚡ 1151 µs 🐌 1280 µs
⚡ 510 µs 🐌 1054 µs
⚡ 482 µs 🐌 1442 µs
⚡ 1124 µs 🐌 1370 µs
⚡ 621 µs 🐌 1008 µs
With @torch.compile added to custom_kernel (not sure if this works in the container solution, but it gives a much faster runtime):
⚡ 26.6 µs 🐌 32.0 µs
⚡ 198 µs 🐌 219 µs
⚡ 61.8 µs 🐌 69.4 µs
⚡ 56.3 µs 🐌 62.6 µs
⚡ 199 µs 🐌 216 µs
⚡ 106 µs 🐌 116 µs
Aiter baseline:
⚡ 11.1 µs 🐌 17.2 µs
⚡ 24.3 µs 🐌 31.2 µs
⚡ 11.4 µs 🐌 16.8 µs
⚡ 11.4 µs 🐌 16.5 µs
⚡ 13.4 µs 🐌 18.8 µs
⚡ 12.1 µs 🐌 16.8 µs