tgoop.com/nn_for_science/2376
Last Update:
Как обучить диффузионную модель с нуля за $1890?
Законы масштабирования в генеративном ИИ повышают производительность, но есть ньюанс: разработка моделей концентрируется среди игроков с большими вычислительными ресурсами.
Поскольку стоимость обучения text-to-image трансформера растет с количеством участков в каждом изображении, исследователи из Sony AI предложили случайным образом маскировать до 75% участков изображения во время обучения.
Применяется стратегия отложенного маскирования, которая предварительно обрабатывает все участки с помощью
микшера участков перед маскированием, тем самым значительно снижая ухудшение производительности процесса. Для оптимизации вычислительных затрат данный подход со работает лучше, чем уменьшение масштаба модели.
В исследование также включили последние
улучшения в архитектуре трансформеров, такие как использование слоев с mixture of experts (MoE),чтобы улучшить производительность и убедиться в важности использования синтетических изображений для уменьшения затрат на обучение.
Какие результаты?
Используя только 37 млн изображений (22 млн реальных + 15 млн синтетических), была обучена модель типа "sparse transformer" с 1,16 млрд параметров.
На обучение было потрачено всего 1890$ !
Была достигнута производительность 12,7 FID при zero shot learning на наборе данных COCO.
Примечательно, что модель достигает конкурентоспособного FID и высококачественных генераций, при этом требуя в 118 раз меньших затрат, чем стабильные диффузионные модели, и в 14 раз меньших затрат, чем текущий современный подход, который стоит 28400$
🔍 Технические детали:
• Архитектура: sparse DiT-XL/2 трансформер
• Вычисления: 8×H100 GPU на 2,6 дня тренировки
• VAE: использование как SDXL-VAE (4 канала), так и Ostris-VAE (16 каналов)
• Патч-миксер перед трансформером + маскирование 75% патчей
• Обучение: 280K шагов на 256×256, затем 55K шагов на 512×512
• Размер батча: 2048, с применением центрального кропа
📊 Доступные предобученные модели:
1. MicroDiT_XL_2 на 22 млн реальных изображениях (FID 12.72)
2. MicroDiT_XL_2 на 37 млн изображениях (FID 12.66) с SDXL-VAE
3. MicroDiT_XL_2 на 37 млн изображениях (FID 13.04) с Ostris-VAE
4. MicroDiT_XL_2 на 490 млн синтетических изображениях (FID 13.26)
💻 Репозиторий содержит полный код, включая обработку датасетов и тренировочные конфиги для каждого этапа
🔗 Статья