tgoop.com/quant_prune_distill/268
Last Update:
BitsFusion: 1.99 bits Weight Quantization of Diffusion Model
[Статья][Ридми]
Yet another свежая статья про экстремальную квантизацию диффузионок от Snap Research. Утверждают, что 2-битная модель даже лучше 😱 исходной. Если у вас завалялся старый игровой ПК, года так 2007, можно его не выбрасывать, а генерить картинки стейбл диффузией 😅.
Метод
По существу предложенный метод представляет собой QAT (квантизейшн эвер трэйнинг) с mixed-precision квантизацией и дистилляцией.
Каждый слой квантуется в 1️⃣, 2️⃣, 3️⃣ бита, или не квантуется вовсе. В прошлых работах было показано, что диффузионные модели очень чувствительны к сжатию timestep ⏳ проекций, обуславливающих на текущий шаг расшумления. Поэтому следуя стандартной практике их не сжимают. Авторы анализируют чувствительность разных слоев к сжатию, замеряя MSE с картинкой сгенерированной оригинальной моделью и CLIP Score, при квантовании слоев по отдельности и замечают, что данные две метрики не всегда скоррелированны. В частности, сжатие слоев в cross attention слоях не сильно так сильно влияет на MSE, но при этом временами ломает семантику изображения. Shortcut свертки очень важны.
Каждому слою сопоставляется важность в зависимости от MSE и числа параметров в нем, и подбирается порог такой, что достигается целевая степень сжатия.
Min-Max стратегия по определению масштабов квантования не учитывает наличие выбросов в распределении весов, поэтому авторы применяют Lloyd-Max итеративный алгоритм для минимизации ошибки. Кроме того, важно учитывать симметрию весов относительно нуля и явно накладывать ее.
Далее авторы дообучают 👨🏫 квантованную модель.
На первой стадии сжатая модель пытается воспроизвести выход и промежуточные активации учителя. Авторы отмечают, что важно учитывать classifier-free guidance, используемый при инференсе. Распределение шагов сэмплирования смещено в область где ошибка квантизации больше (поздние шаги диффузии).
На второй стадии модель учится на noise prediction, как самая обычная диффузионка.
Эксперименты
Берут SD v1.5 модель и учат 20к шагов на первой стадии, и 50к на второй на некотором проприетарном датасете. Замеряют CLIP Score и FID на MS-COCO, TIFA, GenEval и пользовательские предпочтения на PartiPrompts.
Сжатая модель после второй стадии по метрикам примерно равна несжатой.
На SbS (side-by-side сравнении) на Parti Prompts BitsFusion модель побеждает SDv1.5 с win rate 54.4% против 45.6%. Странно, что SbS, который самый показательный, учитывая несовершенство текущих генеративных метрик, скромно запрятан в приложение.
В ablation показывают, что более-менее все компоненты метода важны - смешанная точность, дистилляция признаков, сэмплирование шагов и двустадийное обучение.
Вывод
Довольно хороший инженерный результат, использующий сочетание разных идей. Разве, что без специализиованного железа вряд ли удастя выжать ускорение. Однако, вызывает ❔, почему была выбрана SD v1.5, хотя статья свежая, и уже почти как год существует SDXL. Можно ли их так же легко сжать? Полагаю, что хорошее качество еще во многом обусловлено тем фактом, что загадочный проприетарный датасет неплохо отфильтрованных и дообучение несжатой модели могло бы ее тоже улучшить, ибо SD v1.5 училась на довольно шумных данных из LAION.
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/268