tgoop.com/quant_prune_distill/54
Last Update:
Flash Attention-2: Faster Attention with Better Parallelism and Work Partitioning
[Статья][Код]
На днях Tri Dao (это имя и фамилия, а не три пути из китайской философии) выпустил сиквел знаменитого блокбастера Flash Attention - Flash Attention 2. Обновленная версия Flash Attention примерно в два раза быстрее своей предшественницы и местами до 10 раз бывает быстрее наивной реализации PyTorch.
Введение
Потребность в работе с длинным контекстом возникает в ряде приложений - задачах связанных с написанием или пониманием книг, обработкой видео, аудио. Наивная реализация Attention требует квадратичного по длине последовательности объема памяти и вычислений, из-за чего на практике редко используют контекст более 1k токенов. В свое время много работ и усилий исследователей было потрачено на разработку альтернатив алгоритму внимания в его исходной формулировке, но нельзя сказать, чтобы один из множества подходов оказался конкурентоспособен на широком наборе задач.
В работе по Flash Attention добились значительного ускорения Attention и снижения пикового расхода памяти за счет оптимизации входящих в него операции. При этом сам алгоритм математически эквивалентен (up to numerical precision) исходному Attention.
Напомню, что в основе оригинальной работы по Flash Attention лежат следующие наблюдения:
1️⃣ Большую часть времени занимают не вычисления, а работа с памятью.
2️⃣ Для подсчета Attention не обязательно материализовывать всю матрицу Attention, квадратичную по длине последовательности целиком. Можно считать ее поблочно, а затем агрегировать результат.
Flash Attention уменьшает количество операций по перекачке данных с HBM памяти в кэши GPU за счет выполнения нескольких математических операций сразу в пределах одного блока (kernel fusion). При обратном проходе и подсчете градиентов по параметрам матрица Attention пересчитывается снова поблочно. Flash Attention делает больше вычислений, чем исходный алгоритм, но так как вычисления в разы быстрее перекачки памяти туда-сюда, получается значительный выигрыш в производительности.
Однако, даже оптимизированный Flash Attention сильно недоиспользует возможности современных ускорителей вычислений, достигая всего-лишь 25-40% от теоретической максимальной производительности, в то время как матричные умножения могут достигать 80-90%.
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/54