QUANT_PRUNE_DISTILL Telegram 338
MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models

[Статья] [Страница проекта] [Код][Пост на Machine Learning]

Введение

2:4 (она же semistructured sparsity) дает какое-никакое ускорение на GPU от Ampere и новее. Однако, просадка при прунинге обычно слишком велика для LLMок, дабы быть интересной на практике.

В этой статье предлагают метод обучения хороших 2:4 масок через Gumbel-Softmax.

Метод

Маска суть дискретная сущность потому ее просто так не отпизируешь градиентным спуском, и авторы предлагают моделировать распределение масок через Gumbel-Softmax с 6 = binom(2, 4) вариантам. На обучении оптимизируются логиты вероятности сэмплирования одного из вариантов масок (т.е маска есть взвешенная сумма возможных вариантов), а на инференсе берется наиболее вероятный. Обучение суть просто оптимизация кросс-энтропии (как на pretrain). Веса при этом заморожены.

Если какой-то вес зануляется или близок к нулю, то логиты маски почти не получают градиентов, потому авторы добавляют регуляризационный член как weight_decay, но со знаком , чтобы расталкивать веса от нуля, тем самым поддерживая не нулевую норму у немаскированных весов.

Кроме того, маски полученные условным SparseGPT/Wanda являются хорошей инициализацией для масок и позволяют чуть улучшить результат.

Эксперименты

Метод валидируют на 🦙-2, Nemotron-4 15B и двух маленьких проприетарных GPT-3. Замеряют по классике перплексию на Wikitext и 0-шоты.

По метрикам опережают уверенно все бейзлайны (SparseGPT, Wanda, Magnitude). SparseGPT, правда, можно завести и получше. В отличие от алгоритмов one-shot прунинга, которые быстро насыщаются от количества данных, MaskLLM продолжает улучшаться при большем и большем количестве данных, что неудивительно ибо это есть по сути метод оптимизации с большим количеством обучаемых параметров.

Ablations:
1️⃣ Инициализация маской от one-shot прунера накидывает в конечном качестве.
2️⃣ Достаточная степень стохастичности сэмплирования важна для хорошего качества, дабы модель могла “попробовать” разные варианты масок.
3️⃣ Анти-weight decay не то чтобы сильно, но улучшает качество.
4️⃣ Кроме того, полученную маску можно оптимизировать на downstream и даже временами оверфитнуться улучшить перплексию по сравнению с floating-point моделью.

Вывод

Вполне годная стратегия для обучения 2:4, но требующая определенных вычислительных затрат (т.е прилично дороже чем прогнать SparseGPT). Результат достойный, но все же просадка остается довольно заметной - больше чем у SOTA методов 2-битной квантизации. Вероятно, если еще оптимизировать веса вместе с масками - можно выжать больше.
👍62



tgoop.com/quant_prune_distill/338
Create:
Last Update:

MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models

[Статья] [Страница проекта] [Код][Пост на Machine Learning]

Введение

2:4 (она же semistructured sparsity) дает какое-никакое ускорение на GPU от Ampere и новее. Однако, просадка при прунинге обычно слишком велика для LLMок, дабы быть интересной на практике.

В этой статье предлагают метод обучения хороших 2:4 масок через Gumbel-Softmax.

Метод

Маска суть дискретная сущность потому ее просто так не отпизируешь градиентным спуском, и авторы предлагают моделировать распределение масок через Gumbel-Softmax с 6 = binom(2, 4) вариантам. На обучении оптимизируются логиты вероятности сэмплирования одного из вариантов масок (т.е маска есть взвешенная сумма возможных вариантов), а на инференсе берется наиболее вероятный. Обучение суть просто оптимизация кросс-энтропии (как на pretrain). Веса при этом заморожены.

Если какой-то вес зануляется или близок к нулю, то логиты маски почти не получают градиентов, потому авторы добавляют регуляризационный член как weight_decay, но со знаком , чтобы расталкивать веса от нуля, тем самым поддерживая не нулевую норму у немаскированных весов.

Кроме того, маски полученные условным SparseGPT/Wanda являются хорошей инициализацией для масок и позволяют чуть улучшить результат.

Эксперименты

Метод валидируют на 🦙-2, Nemotron-4 15B и двух маленьких проприетарных GPT-3. Замеряют по классике перплексию на Wikitext и 0-шоты.

По метрикам опережают уверенно все бейзлайны (SparseGPT, Wanda, Magnitude). SparseGPT, правда, можно завести и получше. В отличие от алгоритмов one-shot прунинга, которые быстро насыщаются от количества данных, MaskLLM продолжает улучшаться при большем и большем количестве данных, что неудивительно ибо это есть по сути метод оптимизации с большим количеством обучаемых параметров.

Ablations:
1️⃣ Инициализация маской от one-shot прунера накидывает в конечном качестве.
2️⃣ Достаточная степень стохастичности сэмплирования важна для хорошего качества, дабы модель могла “попробовать” разные варианты масок.
3️⃣ Анти-weight decay не то чтобы сильно, но улучшает качество.
4️⃣ Кроме того, полученную маску можно оптимизировать на downstream и даже временами оверфитнуться улучшить перплексию по сравнению с floating-point моделью.

Вывод

Вполне годная стратегия для обучения 2:4, но требующая определенных вычислительных затрат (т.е прилично дороже чем прогнать SparseGPT). Результат достойный, но все же просадка остается довольно заметной - больше чем у SOTA методов 2-битной квантизации. Вероятно, если еще оптимизировать веса вместе с масками - можно выжать больше.

BY КПД


Share with your friend now:
tgoop.com/quant_prune_distill/338

View MORE
Open in Telegram


Telegram News

Date: |

“Hey degen, are you stressed? Just let it all out,” he wrote, along with a link to join the group. Telegram message that reads: "Bear Market Screaming Therapy Group. You are only allowed to send screaming voice notes. Everything else = BAN. Text pics, videos, stickers, gif = BAN. Anything other than screaming = BAN. You think you are smart = BAN. The court said the defendant had also incited people to commit public nuisance, with messages calling on them to take part in rallies and demonstrations including at Hong Kong International Airport, to block roads and to paralyse the public transportation system. Various forms of protest promoted on the messaging platform included general strikes, lunchtime protests and silent sit-ins. Invite up to 200 users from your contacts to join your channel Activate up to 20 bots
from us


Telegram КПД
FROM American