tgoop.com/quant_prune_distill/287
Last Update:
Fast as CHITA: Neural Network Pruning with Combinatorial Optimization
[Статья][Без Кода]
CHITA - это гепард, а не административный центр Забайкальского края.
Статья 2023 года от Google Research про прунинг.
Введение
Optimal 🧠 Surgeon, использующий приближение 2️⃣ порядка для определения оптимальной сжатой конфигурации, лежит в основе многих методов прунинга и квантизации. Алгоритм Iterative Hard Thresholding (IHT), в котором на каждом шаге делается шаг оптимизации и прунинг самых маленьких весов, лежит в основе некоторых методов sparse training - в частности, AC/DC.
В данной статье решили поженить OBS и IHT и назвали полученную сущность CHITA (Combinatorial Hessian-free Iterative Thresholding Algorithm).
Метод
Основная проблема с Гессианом нейронной сети, что его хрен посчитаешь честно. Существуют различные приближения; в данной статье опираются на Фишеровское (оценка Гессиана сумой внешних произведений градиентов). Обычно градиентов (n)
много меньше, чем параметров (d)
, потому полученная матрица низкоранговая. Можно ее даже не материализовывать, так как в итоговом алгоритме потребуются только матрично-векторные произведения, а хранить лишь n
градиентов в матрице A \in R^{n x d}
. Фишеровская матрица выражается как A^T A
.
CHITA на каждом шаге делает шаг градиентного спуска для квадратичного разложения в окрестности оптимума (т.е L(w) = L_0 + g^ T + 1/2 w^T H w
) с последующим прореживанием как в IHT. Метод требует времени и памяти линейной по числу параметров и сохраненных градиентов.
Оптимальный шаг оптимизации (learning rate в IHT) находят с помощью умного 🧠 алгоритма поиска на прямой.
Эксперименты
Метод валидируют на небольших CNNках на CIFAR-10 и ImageNet. Метод быстрее M-FAC и при этом достигает лучшего качества. Однако честность сравнения в one-shot pruning, учитывая, что метод итеративный, в отличие от M-FAC, вызывает вопросы.
Метод можно применять итеративно, постепенно повышая степень сжатия и дообучая сеть с фиксированной маской. Лучше повышать линейно уровень прореживания линейно, чем экспоненциально или за раз пытаться сжать все.
При итеративном сжатии удается добиться 1% просадки на MobileNet-v1 при 75% cжатии и 5% просадки при 89%. В этом режиме почему-то забывают в сравнениях про M-FAC 😂, который достигает близких результатов по качеству. Метод довольно неплох и на ResNet50 - с умеренными просадками при 90-95% sparsity.
Выводы
Метод, неплох и не слишком сложен в реализации. Основная проблема - необходимость хранить много градиентов, что не позволяет масштабировать метод на LLMки и интересные в 2024 году модели. У вашего покорного слуги есть одна идейка, как это обойти. Но это пока секрет)
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/287