tgoop.com/quant_prune_distill/266
Last Update:
TerDiT: Ternary Diffusion Models with Transformers
[Статья][Код инференса]
LLMки не квантовал за последние год-два только ленивый, потому назрело время осваивать и другие ниши. Квантование диффузионных моделей на текущий момент пока не столь исследовано, как LLMки, в связи с тем, что сами модели не доросли до своих собратьев из NLP, потому и не было столь острой необходимости. Тем не менее, прогресс не стоит на месте, и стоит быть готовым к дальнейшему масштабированию диффузионных тушек.
В рассматриваемой статье авторы перенесли метод тернарной квантизации (quantization-aware-training) QAT из BitNet1.58 на DiTы для class-conditional генерации на ImageNet. Квантуют только веса (активации остаются в исходной точности).
Метод
По существу ничего нового по сравнению с BitNet1.58, веса обучаются через straight-through estimator (STE) с большим learning rate.
Единственное нововведение - нормализация на выходе AdaLayerNorm. Авторы обнаружили, что тернарные веса выдают большие активации, и scale/shift/gate модуляции слишком велики, чтобы сетка могла нормально сходиться. Навешивание RMSNorm на конец MLP для получения модуляций решает проблему.
Эксперименты
Метод валидируют на DiTах двух размеров - с 600M параметров (примерно как DiT-XL из оригинальной статьи) и 4.2B параметров - на class-conditional генерацию на ImageNet 256x256.
По метрикам, тернарная 4.2B модель примерно равна DiT-XL, 600M несколько хуже. То есть для большой модели близкое качество к floating point модели при чуть меньшем общем размере модели (параметров больше в 7 раз, бит на параметр ~10 меньше, чем в fp16). Справедливости ради, стоит заметить что TerDiT обучался меньшее число итераций по сравнению с моделью из статьи фейсбука.
С инференсом немного грустненько 😢 получилось. Для работы с тернарными весами берут кернелы из HQQ и деквантизуют на ходу. Квантованные модели медленнее 😱 fp32 на 20-25%, а при опущенном сравнении с fp16 замедление было бы порядка 3 раз. Зато неплохая экономия по памяти. 4.2B моделька есть 3Gb видеопамяти на пике при инференсе.
В приложении еще зачем-то показывают что существующие 4-битных квантизации ломают полностью DiT. Берут, правда SmoothQuant, который в отсутствие квантования активаций, вырождается в round-to-nearest (RTN), т.е самый наивный и грубый метод, при существовании куда более сильных PTQ методов для диффузии (Q-Diffusion, PTQ4DM).
Вывод
С одной стороны, очередное подтверждение того, что тернарный QAT как-то да работает. Однако результат куда скромнее того, что получили для LLM майкрософты, и с таким замедлением инференса вряд ли интересен практикам. Неизвестно, масштабируется ли он на случай более сложной задачи text-2-image генерации. Тем не менее деятельности представляет определенный интерес, и развитием эффективных алгоритмов QAT, вероятно, тернарные модели вполне могут быть около Парето-оптимальными. Во всяком случае, в некоторых приложениях.
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/266