Flash attention in practice 🔥
PyTorch 2.0 has flash attention built in; here's how to use it:
1. Replace your attention op with torch.nn.functional.scaled_dot_product_attention
2. Use 16-bit floats (which you should be using for training anyway)
3. Make sure your head dim is a multiple of 8 and no larger than 128
See the git diff above for an example, and the sketch below.
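Below is a minimal sketch of what the swap looks like, assuming a standard multi-head self-attention module; the class name, shapes, and hyperparameters are illustrative, not taken from the actual diff attached to the post:

```python
import torch
import torch.nn.functional as F


class SelfAttention(torch.nn.Module):
    """Standard multi-head self-attention with the manual softmax replaced by SDPA."""

    def __init__(self, dim: int = 768, n_heads: int = 12):
        super().__init__()
        head_dim = dim // n_heads
        assert head_dim % 8 == 0 and head_dim <= 128  # step 3
        self.n_heads = n_heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, seq, dim)
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # reshape to (batch, n_heads, seq, head_dim)
        q, k, v = [t.view(b, s, self.n_heads, d // self.n_heads).transpose(1, 2) for t in (q, k, v)]

        # Before: manual attention
        #   att = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        #   y = att.softmax(dim=-1) @ v
        # After: one fused op; the flash kernel is dispatched automatically when inputs qualify
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        return self.proj(y.transpose(1, 2).reshape(b, s, d))


# 16-bit floats (step 2) via autocast; batch size and sequence length here are made up
model = SelfAttention().cuda()
x = torch.randn(8, 256, 768, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    out = model(x)  # (8, 256, 768)
```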
Results:
1. 2010 examples/sec ⟼ 2790 examples/sec, a ~40% speedup (8x4090 setup)
2. Memory: 22 GB ⟼ 16 GB at sequence length 256
3. Exactly the same model, no approximations
(In my case, a big chunk of the improvement also came at the cost of reducing softmax precision from fp32 to bf16, but to hell with that.)
Flash attention should yield even larger gains at longer sequence lengths.
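One way to confirm the flash kernel is actually being used (rather than the math fallback) is to temporarily disable the other SDPA backends; the call then errors out instead of silently falling back when your dtype or head dim doesn't qualify. This is a sketch against the PyTorch 2.0 API; newer releases expose the same switch as torch.nn.attention.sdpa_kernel.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); bf16 and head_dim=64 satisfy the requirements above
q = torch.randn(8, 12, 256, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the flash backend: if flash attention can't be dispatched for these
# inputs, this raises instead of quietly using a slower kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([8, 12, 256, 64])
```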