[SP] add SP deny list instead of allow #7887
kashif wants to merge 12 commits into deepspeedai:master from …
Conversation
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
tohtana left a comment
Hi @kashif,
Thank you for opening this PR! I think supporting HF hub kernels is a significant update.
Regarding the approach, we check if core_attn_implementation is in ALL_ATTENTION_FUNCTIONS, but HF hub kernels like kernels-community/flash-attn2 are not in the list. So HF hub kernels still won't be available with this fix.
We probably need to do the proper registration steps:
- Reject known-bad impls explicitly: eager, flex_attention, and probably paged|eager.
- If core_attn_implementation is an HF hub kernel string, call the HF registration path first (using lazy_import_flash_attention(…)).
- Then read core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation].
- Build uattn from that original function.
- Replace that key with uattn_wrapper.
Does it make sense to you?
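For concreteness, here is a minimal sketch of that flow, not the PR's actual code. Assumptions: ALL_ATTENTION_FUNCTIONS is importable from transformers.modeling_utils, lazy_import_flash_attention is the HF helper mentioned above (import path assumed), and make_uattn_wrapper is a hypothetical factory standing in for the Ulysses wrapper construction:

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

DENY_LIST = ("eager", "flex_attention", "paged|eager")

def register_ulysses_attn(core_attn_implementation: str, make_uattn_wrapper):
    # 1. reject known-bad implementations explicitly
    if core_attn_implementation in DENY_LIST:
        raise ValueError(f"{core_attn_implementation} is not supported with Ulysses SP")

    # 2. HF hub kernel strings (e.g. "kernels-community/flash-attn2") are not
    #    pre-registered, so run the HF registration path first
    #    (import path of lazy_import_flash_attention is an assumption here)
    if core_attn_implementation not in ALL_ATTENTION_FUNCTIONS:
        from transformers.modeling_flash_attention_utils import lazy_import_flash_attention
        lazy_import_flash_attention(core_attn_implementation)

    # 3. read the original function, 4. build uattn from it,
    # 5. replace the key with the wrapper
    core_attn_function = ALL_ATTENTION_FUNCTIONS[core_attn_implementation]
    ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = make_uattn_wrapper(core_attn_function)
```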
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Thanks @tohtana, I have tried to fix all the issues raised; could you kindly check again?
We actually don't know if flex_attention is bad, we just haven't tried it out. Do you have resources to try it out, Kashif? Same for the others on the list. That's why we started with an allow list rather than a deny list. The only reason eager is denied is that it requires a 4D attention_mask, which is a bad idea for long sequences.

BTW, SDPA is silently broken with packed samples - when there is no attn mask, it ignores pos ids and attends to the whole sequence instead. Expect bad results. Not sure how to flag that to users - probably we need to inspect pos ids, see if they reset at least once, and disallow sdpa then.
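A hedged sketch of the inspection Stas suggests, assuming position_ids of shape [batch, seq_len]; the function name is illustrative:

```python
import torch

def looks_packed(position_ids: torch.Tensor) -> bool:
    """True if any row's position ids reset, i.e. look like [0..N, 0..M, ...]."""
    # a reset shows up as a negative step between consecutive position ids
    diffs = position_ids[:, 1:] - position_ids[:, :-1]
    return bool((diffs < 0).any())
```

If this returns True and the user asked for sdpa with no attention mask, the adapter could then refuse to proceed.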
Hi @kashif, I also think Stas's comment makes sense. Can you try implementing such a validation?
Sure @tohtana, I can check.
To make things more exact: it's packed samples + pos ids + 4D.
Oh, Kashif, I'm being told …
I ran some experiments comparing flash_attention_2, sdpa, and flex_attention with SP=4 on Qwen3-4B (GQA: 32 Q heads / 8 KV heads).

Without SP (1 GPU baseline): flash_attention_2 and sdpa produce identical losses, confirming the backends are equivalent without SP.

With SP=4 (4 GPUs): sdpa and flex_attention match each other, but both diverge significantly from flash_attention_2.

@stas00 any ideas on what flash_attention_2 might be doing differently after the all-to-all that could explain this divergence?
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Ok @stas00, I now generate position_ids if missing from the batch, build a causal BlockMask for flex_attention, and do a one-time validation for packed samples + sdpa/eager. Now the outputs are matching: …
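For reference, a minimal sketch (not the PR's exact code) of building a document-causal BlockMask for flex_attention from packed position_ids, where each reset to 0 is taken to start a new document:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def packed_causal_block_mask(position_ids: torch.Tensor):
    # position_ids: [batch, seq_len], e.g. [0,1,2, 0,1, 0,1,2,3] for 3 packed docs
    B, S = position_ids.shape
    # each position-id reset to 0 starts a new document
    doc_ids = torch.cumsum((position_ids == 0).long(), dim=-1)

    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)  # causal within each document

    # H=None broadcasts the mask over all heads
    return create_block_mask(mask_mod, B, None, S, S, device=str(position_ids.device))
```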
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Thank you for running those quality comparison experiments, Kashif.

I'm a bit unclear about your last "success" comment - what was missing to make FA2 match? Are you saying the mismatch was from missing position_ids? But we already said that SDPA (and now most likely FlexAttention) have trouble with no-attn-mask / yes-pos-id and will ignore packed samples. FA2, on the other hand, does the right thing here.

And it's great to hear Flex Attention works with Ulysses as well, so we could add it to the allow list.
```python
if has_packed_samples and self.core_attn_implementation in ("sdpa", "eager"):
    raise ValueError(
```
heh, I thought we were discussing that it's HF Transformers that has to do that, not Ulysses SP. It affects all users regardless of whether they use Ulysses or not. Unless HF Transformers disallows not providing attn-mask with sdpa/eager, which I don't think is the case.
Agree, removed from the DeepSpeed side.
```python
# looks like packed sequences [0,...,N, 0,...,N, ...]. flash_attention_2 handles
# this via flash_varlen_fn, but sdpa/flex_attention apply full causal masking
# across the resets, producing incorrect attention.
if "position_ids" not in batch:
```
I'm not sure about this. This might lead to a user getting the wrong behavior if they packed samples but forgot to supply pos ids. Should we simply assert if pos ids aren't there and not potentially create invalid pos ids?
I agree there needs to be a check and it's not there.
Yes, it would need to be in the TRL trainer: the collator should always provide position_ids when SP is enabled, so the adapter never needs to generate them. I can try to fix it there.
Thank you, Kashif.
And probably then add an assert on the SP side if pos ids aren't there?
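Something like this hedged sketch on the SP side (names hypothetical):

```python
def validate_sp_batch(batch: dict) -> None:
    # fail fast instead of generating potentially invalid position ids
    if "position_ids" not in batch:
        raise ValueError(
            "Ulysses SP requires position_ids in the batch; make sure the "
            "collator provides them (resetting at each packed-sample boundary)."
        )
```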
So, FA2 was the one producing correct results, while SDPA/flex were wrong. Here's what was happening: FA2 "accidentally" handles the packed sequences correctly via its varlen path, while SDPA with no attention mask applies full causal masking across the sample resets.

The fix: generate position_ids when they are missing from the batch, and build a causal BlockMask for flex_attention.

With this fix, all three backends match within numerical precision: …
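On the TRL/collator side, "provide position_ids" for packed samples amounts to something like this sketch, where seq_lengths is a hypothetical list of per-document lengths:

```python
import torch

def packed_position_ids(seq_lengths: list[int]) -> torch.Tensor:
    # e.g. [3, 2, 4] -> [[0,1,2, 0,1, 0,1,2,3]]: ids reset at every boundary
    return torch.cat([torch.arange(n) for n in seq_lengths]).unsqueeze(0)
```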
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
great explanations, Kashif - thank you!
Thank you, Kashif.
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
@stas00, regarding point 2, we added …
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
Thank you very much, Kashif. Do you think all this amazing tooling you added should live here and not in HF Transformers?
Checking.
So some SP-specific things tied to the all-to-all make sense to be here...
Agree that in Transformers: …
On the TRL side: …
Thank you for the detailed summary, Kashif. I agree with everything, except:

I think it should assert. Warnings don't work, and allowing invalid training can be so costly to the user who missed the warning in the sea of warnings. I wonder how many people will discover their model has been mistrained when they had no clue that was the case, other than getting bad outcomes.
Please let us know when things are ready for the final review, Kashif.
This way one can register kernels-based flash-attn with SP as well.