tgoop.com/quant_prune_distill/219
Last Update:
fp8 и с чем его едят
[Статья][Док в transformer.engine]
Введение
Невероятный прирост производительности каждого нового поколения Nvidia на графиках в презентациях GTC кроме объективного улучшения архитектуры обусловлен еще и переходом к типам все более низкой ⬇️ точности (что было остроумно подмечено Love. Death. Transformers.).
Как известно, инферить нейросетки с весами низкой точности можно и нужно. Но, вопрос в том, насколько сложно обучать в числах с плавающей точкой пониженной точности.
Наивное обучение в fp16, без automated mixed precision, как известно, нередко приводит к неприятным сюрпризам типа NaN в лоссах и градиентах. Градиенты могут быть слишком большими, чтобы представляться в fp16, либо слишком маленькими, и адаптивно подбираемый loss scale сдвигает гистограмму градиентов в нужный диапазон значений. Есть еще bf16 (поддерживаемый, начиная с архитектуры Ampere), который имеет более широкий диапазон, но большую погрешность представлений.
И в семействе Hopper добавили поддержку вычислений с fp8.
Метод
fp8 - на самое деле, не 1️⃣, а 2️⃣ типа.
E4M3 c 4 битами на экспоненту и 3 на мантиссу используется для весов и активаций, которые обычно лежат в сравнительно узком диапазоне и важнее точность представления, чем возможность принимать экстремальные значения. Бесконечность не представлена.
E5M2 с 5 битами на экспоненту и 3 на мантиссу используется для градиентов и состояний оптимизатора, где допустима большая погрешность в представлении и больше разброс принимаемых значений.
Обычный loss scaling как в half precision не работает, и приходится иметь для каждого тензора адаптивный масштаб (как в квантизации), чтобы загнать его в удобный диапазон. Определять его на лету накладно. И потому предлагается хранить некоторое скользящее среднее максимумов и на него масштабировать.
Эксперименты
Авторы тестируют эффективность обучения в fp8 на ImageNet-1k, LSTM на WMT и претрейне GPT-3 подобного трансформера на Pile.
fp8 почти везде смог показать конечное качество как half precision бейзлайн за исключением MobileNet-v2, где точность просела на 0.5%.
На языковых задачах fp8 модели достигают примерно того же качества, что и half precision.
В ablation показывают, что адаптивный масштаб для каждого тензора при конвертации bfloat16 модели в fp8 важен, иначе заметно проседает качество даже при оптимальном выборе сдвига экспоненты. Предобученный half-precision BERT без проблем представляется в fp8.
Вывод
fp8 тип выглядит вполне полезным и перспективным. Однако, обучение LLMок с ним, по всей видимости, требует дополнительной возни и несколько вагонов с H100 (и маленькую тележку). Потому на текущий момент, известные открытые модели обучались в half precision. Вероятно, OpenAI и Anthropic что-то пробовали шаманить в fp8, но кто об этом расскажет…
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/219