Метод
В работе Flash Attention 2 по существу еще слегка подкрутили процедуру вычисления и повысили степень параллелизма самой операций.
Алгоритм вычисления
Автор заметил, что операции, не являющиеся матричным умножением, выполняются куда медленнее (в 16 раз), чем матричные умножения, потому переписал алгоритм так, чтобы уменьшить их количество. Казалось бы, их количество невелико, но тем не менее, они занимают существенную часть общего времени работы. Кроме того, при авторегрессионной генерации нужна лишь верхнетреугольная часть матрицы Attention, и вместо того, чтобы считать ее, а затем занулять, ее просто не считают. Вот так вот!
Благодаря перечисленным выше нововведениям удается добиться ускорения 2-3x.
Параллелизм
Flash Attention-1 параллелизует вычисления по размеру батча и числу голов в трансформере, но если батч не слишком большой или трансформер не очень огромный, то многие streaming multiprocessors (SM) простаивают. И чтобы не оставлять их без дела, предлагается паралеллизовывать вычисления и по длине последовательности. На прямом проходе ряды матрицы Attention можно считать независимо, а на обратном проходе - колонки. И каждый поток обрабатывает свой токен. Кроме того, для уменьшения коммуникации между варпами (группами потоков), оказывается целесообразным держать куски матриц ключей (Key) и значений (Values) общими для групп поток, а Query свою на варп (в Flash Attention-1 было наоборот). Уменьшение количество операций чтения/записи приводит к дополнительному ускорению.
Результаты
Flash-Attention-2 сравнивается с Flash-Attention из оригинального репозитория, реализации на triton и xformers. Для замеров рассматривают последовательности длиной от 512 до 16k токенов, и слой attention со скрытой размерностью 2048 (64 или 128 голов).
FlashAttention-2 в 1.3-1.5x быстрее на прямом проходе, и до 2x быстрее на обратном проходе по сравнению с Flash-Attention - 1 (особенно велик выигрыш при использовании causal mask). Flash-Attention - 2 использует до 72% теоретической производительности A100. На H100 разница еще заметнее.
Выводы
Данная история поучительна тем, что одна и та же математическая операция в зависимости от реализации, может выполняться принципиально разное время. Замечательный пример того, что насколько учет особенностей железа, время работы различных компонент, сильных и слабых сторон ускорителя вычислений важен при проектировании алгоритмов.
В работе Flash Attention 2 по существу еще слегка подкрутили процедуру вычисления и повысили степень параллелизма самой операций.
Алгоритм вычисления
Автор заметил, что операции, не являющиеся матричным умножением, выполняются куда медленнее (в 16 раз), чем матричные умножения, потому переписал алгоритм так, чтобы уменьшить их количество. Казалось бы, их количество невелико, но тем не менее, они занимают существенную часть общего времени работы. Кроме того, при авторегрессионной генерации нужна лишь верхнетреугольная часть матрицы Attention, и вместо того, чтобы считать ее, а затем занулять, ее просто не считают. Вот так вот!
Благодаря перечисленным выше нововведениям удается добиться ускорения 2-3x.
Параллелизм
Flash Attention-1 параллелизует вычисления по размеру батча и числу голов в трансформере, но если батч не слишком большой или трансформер не очень огромный, то многие streaming multiprocessors (SM) простаивают. И чтобы не оставлять их без дела, предлагается паралеллизовывать вычисления и по длине последовательности. На прямом проходе ряды матрицы Attention можно считать независимо, а на обратном проходе - колонки. И каждый поток обрабатывает свой токен. Кроме того, для уменьшения коммуникации между варпами (группами потоков), оказывается целесообразным держать куски матриц ключей (Key) и значений (Values) общими для групп поток, а Query свою на варп (в Flash Attention-1 было наоборот). Уменьшение количество операций чтения/записи приводит к дополнительному ускорению.
Результаты
Flash-Attention-2 сравнивается с Flash-Attention из оригинального репозитория, реализации на triton и xformers. Для замеров рассматривают последовательности длиной от 512 до 16k токенов, и слой attention со скрытой размерностью 2048 (64 или 128 голов).
FlashAttention-2 в 1.3-1.5x быстрее на прямом проходе, и до 2x быстрее на обратном проходе по сравнению с Flash-Attention - 1 (особенно велик выигрыш при использовании causal mask). Flash-Attention - 2 использует до 72% теоретической производительности A100. На H100 разница еще заметнее.
Выводы
Данная история поучительна тем, что одна и та же математическая операция в зависимости от реализации, может выполняться принципиально разное время. Замечательный пример того, что насколько учет особенностей железа, время работы различных компонент, сильных и слабых сторон ускорителя вычислений важен при проектировании алгоритмов.
👍2
Stack More Layers Differently: High-Rank Training Through Low-Rank Updates
[Статья][Код]
Обучение всех параметров больших языков моделей весьма прожорливо по памяти из-за необходимости хранить кроме самой тяжеловесной модели еще и состояния оптимизатора (8 байт на параметр).
LoRA, один из самых ходовых методов PEFT, заключающийся в обучении низкоранговых добавок к весам позволяет сильно сэкономить по памяти, демонстрируя при этом хорошее качество при обучении предобученной модели на downstream задачах. Но низкоранговые представления имеют место при дообучении, в то время как для эффективного предобучения на разнообразных данных желательно использовать все имеющуюся в распоряжении емкость сети - то есть обучение должно быть высокоранговым.
В данной статье авторы предлагают метод последовательного обучения низкоранговых добавок к весам линейных слоев нейронной сети с последующим их слиянием с основными весами. И как утверждается, подобная процедура для достаточно больших сетей (самая большая обученная сеть имеет 350M параметров - сущий пустяк по современным меркам), работает ненамного хуже стандартной полноранговой процедуры обучения.
Метод
Ранг суммы двух и более матриц ограничен сверху суммой рангов матриц. Если низкоранговые матрицы в достаточной мере взаимно независимы, то их сумма может иметь значительно больший ранг чем каждое слагаемое по отдельности. Последовательно обучая низкоранговые добавки возможно в итоге добиться высокорангового изменения весов матрицы, В этом и суть метода.
Однако, чтобы метод заработал, авторам пришлось учесть ряд нюансов и применить пару трюков.
Во-первых, используемый при обучении трансформеров Adam хранит скользящие статистики градиентов, и при переходе к обучению новой низкоранговой добавки, если не предпринимать никаких действий, оптимизация будет проводиться в том же подпространстве, что и у предыдущей LoRA добавки, нивелируя всякий смысл в итеративной процедуре. Для предотвращения такого сценария, авторы зануляют 99% состояний оптимизатора с меньшей абсолютной величиной (почему не все? почему не любую другую долю?) при инициализации новой добавки.
Кроме того, learning rate в момент начала обучения новой добавки зануляется и потом быстро разогревается до примерно того же значения, с которым закончила обучение прошлая добавка (используется cosine annealing learning rate). Без короткой warmup фазы обучение расходится.
Предложенная cтратегия именуется ReLoRA.
[Статья][Код]
Обучение всех параметров больших языков моделей весьма прожорливо по памяти из-за необходимости хранить кроме самой тяжеловесной модели еще и состояния оптимизатора (8 байт на параметр).
LoRA, один из самых ходовых методов PEFT, заключающийся в обучении низкоранговых добавок к весам позволяет сильно сэкономить по памяти, демонстрируя при этом хорошее качество при обучении предобученной модели на downstream задачах. Но низкоранговые представления имеют место при дообучении, в то время как для эффективного предобучения на разнообразных данных желательно использовать все имеющуюся в распоряжении емкость сети - то есть обучение должно быть высокоранговым.
В данной статье авторы предлагают метод последовательного обучения низкоранговых добавок к весам линейных слоев нейронной сети с последующим их слиянием с основными весами. И как утверждается, подобная процедура для достаточно больших сетей (самая большая обученная сеть имеет 350M параметров - сущий пустяк по современным меркам), работает ненамного хуже стандартной полноранговой процедуры обучения.
Метод
Ранг суммы двух и более матриц ограничен сверху суммой рангов матриц. Если низкоранговые матрицы в достаточной мере взаимно независимы, то их сумма может иметь значительно больший ранг чем каждое слагаемое по отдельности. Последовательно обучая низкоранговые добавки возможно в итоге добиться высокорангового изменения весов матрицы, В этом и суть метода.
Однако, чтобы метод заработал, авторам пришлось учесть ряд нюансов и применить пару трюков.
Во-первых, используемый при обучении трансформеров Adam хранит скользящие статистики градиентов, и при переходе к обучению новой низкоранговой добавки, если не предпринимать никаких действий, оптимизация будет проводиться в том же подпространстве, что и у предыдущей LoRA добавки, нивелируя всякий смысл в итеративной процедуре. Для предотвращения такого сценария, авторы зануляют 99% состояний оптимизатора с меньшей абсолютной величиной (почему не все? почему не любую другую долю?) при инициализации новой добавки.
Кроме того, learning rate в момент начала обучения новой добавки зануляется и потом быстро разогревается до примерно того же значения, с которым закончила обучение прошлая добавка (используется cosine annealing learning rate). Без короткой warmup фазы обучение расходится.
Предложенная cтратегия именуется ReLoRA.
👍3
Эксперименты
Авторы обучают семейство декодерных моделей моделей от 60 до 350M (типичный размер языковых моделей в 18-19 году) на данных из C4. Архитектура модели повторяет LLaMA.
Процедура обучения состоит из первоначальной фазы полнорангового обучения (т.е обучения всех параметров модели) в течение 5k шагов и 3 циклов обучения низкоранговых добавок на протяжении тех же 5k шагов (с warmup фазой в 100 шагов при переходе к новой LoRA). Пиковый расход памяти такой же, как и в стандартной процедуре обучения.
В качестве бейзлайнов используются:
◦ Стандартное обучение
◦ Обучение меньшей модели с таким же количеством обучаемых параметров, как с LoRA (Control)
◦ LoRA
Метод ожидаемо бьет LoRA, обладая большей выразительностью, и меньшую сеть с тем же числом обучаемых параметров (за исключением самой маленькой модели), при этом несколько уступая стандартной процедуре обучения.
Авторы анализируют спектральное разложение обученных матриц, и у ReLoRA оно больше напоминает изменение весов при обучении всех параметров (по сравнению с LoRA), хоть все еще заметно отличается.
Ablation показывает, что все компоненты метода важны для приемлемого результата - первичная процедура стандартного обучения, зануление состояний отпимизатора и warmup.
Заключение
Довольно интересный и разумный подход. Применимость его в качестве претрейна, по моему мнению, ограничена, из-за необходимости фазы высорангового обучения в начале, из-за чего большие LLM-ки какое-то время придется обучать на множестве хостов. Основной выигрыш может быть при файнтьюнинге на достаточно больших и разнообразных задачах, где выразительности низкоранговых добавок недостаточно.
Авторы обучают семейство декодерных моделей моделей от 60 до 350M (типичный размер языковых моделей в 18-19 году) на данных из C4. Архитектура модели повторяет LLaMA.
Процедура обучения состоит из первоначальной фазы полнорангового обучения (т.е обучения всех параметров модели) в течение 5k шагов и 3 циклов обучения низкоранговых добавок на протяжении тех же 5k шагов (с warmup фазой в 100 шагов при переходе к новой LoRA). Пиковый расход памяти такой же, как и в стандартной процедуре обучения.
В качестве бейзлайнов используются:
◦ Стандартное обучение
◦ Обучение меньшей модели с таким же количеством обучаемых параметров, как с LoRA (Control)
◦ LoRA
Метод ожидаемо бьет LoRA, обладая большей выразительностью, и меньшую сеть с тем же числом обучаемых параметров (за исключением самой маленькой модели), при этом несколько уступая стандартной процедуре обучения.
Авторы анализируют спектральное разложение обученных матриц, и у ReLoRA оно больше напоминает изменение весов при обучении всех параметров (по сравнению с LoRA), хоть все еще заметно отличается.
Ablation показывает, что все компоненты метода важны для приемлемого результата - первичная процедура стандартного обучения, зануление состояний отпимизатора и warmup.
Заключение
Довольно интересный и разумный подход. Применимость его в качестве претрейна, по моему мнению, ограничена, из-за необходимости фазы высорангового обучения в начале, из-за чего большие LLM-ки какое-то время придется обучать на множестве хостов. Основной выигрыш может быть при файнтьюнинге на достаточно больших и разнообразных задачах, где выразительности низкоранговых добавок недостаточно.
LlaMA-2: Open Foundation and Fine-Tuned Chat Models
[Статья][Код]
Не прошло и года (и даже половины года), как запрещенная в России экстремистская организация Meta выпустила новую версию всем полюбившейся LLM-ки: LLaMA-2.
Первая версия модели стала настоящим хитом среди исследователей, практиков, да и простых обывателей, будучи наиболее качественной языковой моделью среди находящихся в публичном доступе. LLaMA стала основой для множества чатботов, получила множество интеграций для запуска на чем угодно начиная от продвинутых GPU и заканчивая калькуляторами и микроволновками.
Нововведения
В плане архитектуры. и процедуры предобучения LLaMA-2 не претерпела значительных изменений.
Вместо стандартного Attention блока, где количество голов в Query, Key, Value проекциях одинаково, и каждому Query соотвествует отдельный Key и Value, используется grouped query attention c 8️⃣ проекциями вместо
Длину контекста увеличили до 4k токенов. RoPE позиционные энкодинги могут работать и с более длинным контекстом.
Данные отфильтровали более тщательно, увеличили в размере на 40% и обучили все модели на 2T токенов вместо 1T.
Итоговая модель на Common Sense Reasoning, Question Answering, World Knowledge, и т.д оказывается лучше прошлой версии и всех других моделей в открытом доступе, но уступает флагманским закрытым - GPT-4, Palm-2-L.
Куда более занимательна (ей же и уделено основное внимание) процедура instruction-finetuning и получения чатбота из языковой модели.
Instruction-finetuning процедура состоит из 2 стадий:
1️⃣ SFT - Supervised Finetuning
2️⃣ RLHF - Reinforcement Learning
Авторы собирают свой собственный датасет из инструкций, в котором акцент был сделан не на количество инструкций, а на их качество и разнообразие (актуальные работы утверждают, что для instruction finetuning данных много и не требуется). В полученном датасете 27540 инструкций.
На первом стадии (SFT) модель обучают на Causal LM, как на этапе преобучения на датасете инструкций. Промты и ответы контатенируют с одну последовательность, разделяя специальным токеном.
Данные для обучения reward модели собирали с помощью человекоподобных разметчиков. Каждый респодент выбирает между двумя вариантами с градацией разницы significantly better, better,
slightly better, or negligibly better/unsure. Для максимизации разнообразия варианты ответов генерируются случайно выбранными моделями из семейства LLaMA-2 с разной температурой.
[Статья][Код]
Не прошло и года (и даже половины года), как запрещенная в России экстремистская организация Meta выпустила новую версию всем полюбившейся LLM-ки: LLaMA-2.
Первая версия модели стала настоящим хитом среди исследователей, практиков, да и простых обывателей, будучи наиболее качественной языковой моделью среди находящихся в публичном доступе. LLaMA стала основой для множества чатботов, получила множество интеграций для запуска на чем угодно начиная от продвинутых GPU и заканчивая калькуляторами и микроволновками.
Нововведения
В плане архитектуры. и процедуры предобучения LLaMA-2 не претерпела значительных изменений.
Вместо стандартного Attention блока, где количество голов в Query, Key, Value проекциях одинаково, и каждому Query соотвествует отдельный Key и Value, используется grouped query attention c 8️⃣ проекциями вместо
num_heads
(т.е каждая Key, Value активация спаривается с num_heads // 8
головами Query). Делать полный multi query с 1 проекций на все головы не стали по двум соображениям - 1) на инференсе они параллелизует вычисления между 8 GPU, и пришлось бы все равно копировать Key, Value между всеми устройствами 2) multi query просаживается по качеству по сравнению с исходным attention, а grouped query имеет примерно то же качество. Данное изменение полезно при авторегрессионой генерации с использованием Key, Value кэшей, так как приводит к заметной экономии в памяти (при той же длине последовательности экономия в num_heads // 8
рвз). Длину контекста увеличили до 4k токенов. RoPE позиционные энкодинги могут работать и с более длинным контекстом.
Данные отфильтровали более тщательно, увеличили в размере на 40% и обучили все модели на 2T токенов вместо 1T.
Итоговая модель на Common Sense Reasoning, Question Answering, World Knowledge, и т.д оказывается лучше прошлой версии и всех других моделей в открытом доступе, но уступает флагманским закрытым - GPT-4, Palm-2-L.
Куда более занимательна (ей же и уделено основное внимание) процедура instruction-finetuning и получения чатбота из языковой модели.
Instruction-finetuning процедура состоит из 2 стадий:
1️⃣ SFT - Supervised Finetuning
2️⃣ RLHF - Reinforcement Learning
Авторы собирают свой собственный датасет из инструкций, в котором акцент был сделан не на количество инструкций, а на их качество и разнообразие (актуальные работы утверждают, что для instruction finetuning данных много и не требуется). В полученном датасете 27540 инструкций.
На первом стадии (SFT) модель обучают на Causal LM, как на этапе преобучения на датасете инструкций. Промты и ответы контатенируют с одну последовательность, разделяя специальным токеном.
Данные для обучения reward модели собирали с помощью человекоподобных разметчиков. Каждый респодент выбирает между двумя вариантами с градацией разницы significantly better, better,
slightly better, or negligibly better/unsure. Для максимизации разнообразия варианты ответов генерируются случайно выбранными моделями из семейства LLaMA-2 с разной температурой.
🔥3
Обучают две reward модели:
1️⃣ Helpfullness (полезность)
2️⃣ Safety (безопасность)
Для моделирования reward используются предобученные чекпоинты с 1-го этапа.
В качестве функции потерь используется бинарная ранжировочная функция потерь из Instruct GPT с добавкой, зависящей от степени увереннности в ответе, чтобы разница в оценках для ответа с большей уверенностью была больше, чем для менее уверенного ответа.
Полученные reward модели сравнивают с теми, что получаются при обучении на других instruction датасетах и GPT4. И по отдельности reward модели оказываются лучше безйлайнов на своих и прочих датасетах (но для GPT4 нет данных на других instruction датасетах).
Затем исследуется scaling поведение от количества данных и размеров модели. Ожидаемо, большие модели и большее количество данных улучшает качество reward модели.
С ростом количества полученных данных от аннотаторов авторы итеративно дообучают reward модель (5-версий) с использованием Proximal Policy Optimization (PPO) и Rejection Sampling.
Нередко перед чатботом ставится задача следовать некоторой инструкции или парадигме поведения на протяжении нескольких раундов вопрос-ответ или всего диалога. Чтобы поддерживать в модели подобный сценарий поведения, авторы статьи используют метод GAtt (Ghost Attention). Ко всем запросам пользователя добавляется целевая инструкция, но чтобы не нарушать распределение данных (диалог, где пользователь повторяет одну и ту же инструкцию много раз смотрится неестественно), лосс от прошлых сообщений в диалоге не учитывается.
Данная модификация действительно способствует следованию ассистентом целевой инструкции.
1️⃣ Helpfullness (полезность)
2️⃣ Safety (безопасность)
Для моделирования reward используются предобученные чекпоинты с 1-го этапа.
В качестве функции потерь используется бинарная ранжировочная функция потерь из Instruct GPT с добавкой, зависящей от степени увереннности в ответе, чтобы разница в оценках для ответа с большей уверенностью была больше, чем для менее уверенного ответа.
Полученные reward модели сравнивают с теми, что получаются при обучении на других instruction датасетах и GPT4. И по отдельности reward модели оказываются лучше безйлайнов на своих и прочих датасетах (но для GPT4 нет данных на других instruction датасетах).
Затем исследуется scaling поведение от количества данных и размеров модели. Ожидаемо, большие модели и большее количество данных улучшает качество reward модели.
С ростом количества полученных данных от аннотаторов авторы итеративно дообучают reward модель (5-версий) с использованием Proximal Policy Optimization (PPO) и Rejection Sampling.
Нередко перед чатботом ставится задача следовать некоторой инструкции или парадигме поведения на протяжении нескольких раундов вопрос-ответ или всего диалога. Чтобы поддерживать в модели подобный сценарий поведения, авторы статьи используют метод GAtt (Ghost Attention). Ко всем запросам пользователя добавляется целевая инструкция, но чтобы не нарушать распределение данных (диалог, где пользователь повторяет одну и ту же инструкцию много раз смотрится неестественно), лосс от прошлых сообщений в диалоге не учитывается.
Данная модификация действительно способствует следованию ассистентом целевой инструкции.
Результаты
LlaMA-2-chat уверенно побеждает чатботов, основанных на моделях в открытом доступе, сопоставимых размеров, и с небольшим отрывом оказывается лучше (с точки зрения человеческих предпочтений) чем ChatGPT при оценке helpfulness на собранных Meta 4k инструкциях.
При обучении на safety данных, с ростом количества safety данных стабильно уменьшается доля небезопасных ответов без просадки по метрике полезности.
По safety (доле небезопасных ответов) и общему рейтингу полезности и безопасности LlaMA-2 чатботы опережают конкуретных открытых чатботов и ChatGPT/PaLM при оценке на собственном бенчмарке из 2k промптов.
Из дополнительных экспериментов авторы показывают, что модель можно научить действовать корректно подав инструкцию относящуюся к заданному времени (например, модель не будет знать ответ на то, кто побелил во Второй мировой войне, если бы запрос был адресован в 1940 году) и хорошо взаимодействует с ToolFormer.
Итог
LLaMA-2 - новая SOTA среди моделей в открытом доступе, и с учетом бешеного прогресса в области, большого интереса в DL-сообществе, за несколько дней с выпуска, народ уже успел изрядно поиграться с моделью, покрутить и повертеть ее. Данная работа - труд скорее инженерный, чем научный, но, безусловно, полезный и важный. Приятное отличие от первой версии, где месяцами можно было ждать одобрения на скачивание весов (хотя все кому надо воспользовались пиратками), в том, что запрос на LlaMA-2 удовлетворяется оперативно (обычно в течение пары часов).
LlaMA-2-chat уверенно побеждает чатботов, основанных на моделях в открытом доступе, сопоставимых размеров, и с небольшим отрывом оказывается лучше (с точки зрения человеческих предпочтений) чем ChatGPT при оценке helpfulness на собранных Meta 4k инструкциях.
При обучении на safety данных, с ростом количества safety данных стабильно уменьшается доля небезопасных ответов без просадки по метрике полезности.
По safety (доле небезопасных ответов) и общему рейтингу полезности и безопасности LlaMA-2 чатботы опережают конкуретных открытых чатботов и ChatGPT/PaLM при оценке на собственном бенчмарке из 2k промптов.
Из дополнительных экспериментов авторы показывают, что модель можно научить действовать корректно подав инструкцию относящуюся к заданному времени (например, модель не будет знать ответ на то, кто побелил во Второй мировой войне, если бы запрос был адресован в 1940 году) и хорошо взаимодействует с ToolFormer.
Итог
LLaMA-2 - новая SOTA среди моделей в открытом доступе, и с учетом бешеного прогресса в области, большого интереса в DL-сообществе, за несколько дней с выпуска, народ уже успел изрядно поиграться с моделью, покрутить и повертеть ее. Данная работа - труд скорее инженерный, чем научный, но, безусловно, полезный и важный. Приятное отличие от первой версии, где месяцами можно было ждать одобрения на скачивание весов (хотя все кому надо воспользовались пиратками), в том, что запрос на LlaMA-2 удовлетворяется оперативно (обычно в течение пары часов).