tgoop.com/quant_prune_distill/181
Last Update:
TinyGSM: achieving > 80% on GSM8k with small language models
[Статья][Кода нет, как и моделей и датасета]
Введение
Обучить здоровенную модель на здоровенном датасете здоровенное число итераций, которая умеет во все и всея - большого ума не надо, поставил обучаться на десятках тысяч видеокарт на несколько месяцев - и готово. А вот получить небольшую модельку, способную в сложные логические конструкции и заключения, решение задач по математике - вот это уже настоящее мастерство и искусство.
В разбираемой статье авторы обучили семейство сетей небольшого размера (от 125M до 1.3B), которые превосходят по доле правильных решений куда более крупных конкурентов.
Метод
В великом множестве статей ранее было показано, что синтетические данные, сгененированные могучей сетью а-ля GPT-3.5/4, позволяют добиться значительно более высокой эффективности обучения по сравнению с типичным корпусами, собранными из интернета. В частности, можно дотюнить (Alpaca, Vicuna, Wizard, Platypus и многое и многое другое), или обучить с нуля (серия моделей Microsoft-Phi и TinyStories).
Математические датасеты невелики по размеру. GSM-8k, в частности, имеет всего 7k примеров в обучающей выборке, и наивное обучение приведет к переобучению. Потому авторы аугментируют данные с помощью GPT-3.5, перефразируя вопросы и добавляя нерелевантный контекст. Чтобы обеспечить качество данных, убирают слишком короткие и не содержащие числа задачи. Кроме того, выфильтровали задачи, которые оказались похожи на тестовые по n-грамному сравнению. Итого вышло более 12M синтетических задач. (~1.8B токенов)
Однако одно лишь это не позволяет преодолеть порог в 70% top-1 accuracy на тесте GSM-8k.
Следующим краеугольным камнем работы является использование модели-verifier (проверятеля), которая оценивает корректность каждого шага решения. Мотивация заключается в том, что одно решение может быть неправильным, но если сгенерировать несколько, то хотя бы одно да залетит. И модель-verifier учится определять корректные шаги в решении.
Обучают ее следующим образом - берут 3 чекпоинта модели, решающей GSM-8k, с разных итераций обучения (более ранние больше ошибаются), и учат предсказывать корректность конкретного токена в решении. Если модель-решатель решила задачу верно, то все токены в последовательности размечаются, как верные - label 1, иначе - наоборот, вся последовательность, как неправильная. Сэмплируют 48 решений на каждый из 7к примеров в обучающей выборке.
Таким образом рецепт успеха состоит из двух основных идей:
1⃣️️️️️️ Аугментация датасета
2⃣️️️️️️ Обучение модели - проверятеля, оценивающей правильность конкретных шагов совместно с моделью, решающей задачи.
BY КПД
Share with your friend now:
tgoop.com/quant_prune_distill/181