tgoop.com/quant_prune_distill/475
Last Update:
Метод
Знание всей прошлой траектории сэмплирования при малошаговом сэмплировании потенциально дает более богатую информацию модели и позволяет скорректировать ошибку. Формально существует однозначное отображение (при детерминированном сэмплере) между шумом и конечной генерацией и хотелось бы быть как можно ближе к этой траектории.
Однако, возникает вопрос - каким образом можно подать прошлые шаги в модель?
И авторы предлагают интересное решение:
🎯 Модель на текущем шаге генерации делает Self Attention на текущий и прошлые шаги сэмплирования с причинной маской (прошлые шаги не аттендятся на будущие).
🎯 В модель добавляется дополнительный эмбеддинг на шаг зашумления. Разные шаги зашумления при генерации получают разные временные эмбеддинги. Пространственные при этом одинаковы.
🎯 Обусловливание на прошлые шаги проводится только в первых N трансформерных блоках. C одной стороны, Attention на прошлые шаги довольно шумный в поздних слоях и даже слегка просаживает качество. В то же время, Self Attention только в первых блоках удешевляет forward pass.
🎯 Выходы KV-кэшей с прошлых шагов можно закэшировать, как в авторегрессионных моделях. Дополнительные вычисления возникают только при вычислении непосресдственно Self Attention.
В качестве базового метода дистилляции используют простой Step Distillation (без Progressive), где модель-ученик пытается воспроизвести траекторию учителя. Для улучшения качества можно дополнительно накинуть адверсариальный лосс на x0.
Также предлагаются два альтернативных подхода маскирования:
1️⃣ Скользящее окно attention (обусловливание на несколько последних шагов)
2️⃣ Attention на текущий сэмпл и начальный шум.
Эксперименты
Метод валидируют на DiT-XL/2 для class-conditional генерации на ImageNet (256x256) и на проприетарной EMU модели для text-2-image. Про последнюю известно, что это трансформер на 1.7B параметров, обученный на некотором проприетарном датасете.
Для дистилляции DiT-XL/2 учителем сэмплируют ~2.5M траекторий при 25 шагах сэмплирования при этом стремясь добиться качественной 4-шаговой генерации. Качество оценивают по FID (на каком количестве сэмплов?), IS, Precision и Recall.
Обусловливание на траекторию значительно улучшает качество по сравнению с ванильной Step Distiilation. Альтернативные варианты масок будто бы чуть хуже по метрикам, но возможно, не статзначимо. GAN-лосс сильно улучшает качество и конечная модель имеет даже меньший FID, чем учитель.
В конечном варианте модель учитывает явно прошлые шаги в первых 6 блоках из 28-ми, а далее работает как исходный DiT.
При генерации 256x256 дополнительный condition на прошлые шаги (несмотря на увеличение количества токенов в Self-Attention на последнем шаге генерации почти в 4 раза) почти не замедляет генерацию. End-to-end время генерации возрастает только на 2% по сравнению с инференсом, использующим только текущий сэмпл. Однако, здесь стоит заметить, что для 256x256 последовательность токенов довольно короткая - (256 = (256/8/2)^2
токенов на одно изображение, т.е 1024 на последнем шаге). Потому вычисление Attention сравнительно недорогое и дополнительный оверхед (благодаря kv-кэшам и включению прошлых шагов только в первых блоках трансформера) должен быть действительно невелик.
На text-2-image генерации качество оценивают на CompBench - бенчмарке, оценивающем релеватность и сравнивают с другими публичными и непубличными дистилированными моделями - SDXL-Lightning, DMD2, ADD, LCM-LoRA, и ImagineFlash и по FID на MSCOCO.
Предложенный подход ARD достигает самого хорошего качества при 3-шаговой генерации слегка опережая 4-шаговую DMD2 по метрикам.
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/475