tgoop.com/gonzo_ML/2938
Last Update:
Were RNNs All We Needed?
Leo Feng, Frederick Tung, Mohamed Osama Ahmed, Yoshua Bengio, Hossein Hajimirsadegh
Статья: https://arxiv.org/abs/2410.01201
Продолжение возрождения рекуррентных сетей. На сей раз снова классика (RNN/LSTM/GRU), а не новомодные SSM (которые ещё и не эквивалентны RNN, в смысле находятся в более простом классе сложности, см. https://www.youtube.com/watch?v=4-VXe1yPDjk).
RNN обладают фундаментальными преимуществами в виде требований к памяти. Они линейны (от размера последовательности) при обучении и константны при инференсе. Чего не скажешь про ванильные трансформеры, у которых квадратичная и линейная сложности соответственно. Один только большой минус был у RNN -- обучение не параллелилось. Обучались они последовательно через backpropagate through time (BPTT), что для длинных последовательностей очень медленно. Здесь преимущества трансформеров в виде параллелизации при всех их недостатках относительно сложности оказалось достаточно, чтобы их обучение скейлилось, и вот мы там где мы есть -- трансформеры вытеснили рекуррентные сети из своих экологических ниш и доминируют почти везде.
Работы последних пары лет устраняют этот недостаток RNN, на свет появились LRU, Griffin, RWKV, Mamba и прочие. Всё это современное разнообразие эффективно параллелится с помощью одного и того же алгоритма -- parallel prefix scan, он же parallel prefix sum (https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf).
Авторы текущей работы адаптируют parallel scan для старых добрых LSTM/GRU, устраняя зависимость input, forget и update гейтов от скрытого состояния (H). Заодно и tanh нелинейность тоже убирают (привет, LRU!). Обычные ванильные RNN не рассматривают, ссылаясь на проблемы с затухающими и взрывающимися градиентами (но вспомним LRU от DeepMind, он как раз был вариацией обычной RNN, https://www.tgoop.com/gonzo_ML/1734).
У LSTM, кстати, тоже было 100500 разных вариантов, например, peephole connections с дополнительной зависимостью гейтов от содержимого ячейки памяти -- помните, у LSTM по факту две переменные состояния, внутреннее состояние ячейки памяти (C, не видно снаружи) и скрытое состояние (H, hidden state, которое снаружи видно). По LSTM, на мой взгляд, есть два фундаментальных источника информации кроме оригинальных статей. Один -- это PhD диссертация Феликса Герса (Felix Gers, http://www.felixgers.de/papers/phd.pdf), который и добавил в архитектуру forget gate (изначально было только два других гейта) + peephole connections. Второй -- PhD диссертация Алекса Грейвса (Alex Graves, https://www.cs.toronto.edu/~graves/phd.pdf), который придумал CTC loss и многомерные RNN. Сила хороших PhD. Ну да ладно.
Авторы получают минималистичные версии LSTM и GRU (minLSTM и minGRU соответственно), которые требуют меньше параметров, параллелятся при обучении и дают хорошее качество. Надо вспомнить, что было в истории много других заходов на рекуррентные сети с быстрым параллельным обучением. Например, QRNN (https://arxiv.org/abs/1611.01576, она более отличается благодаря наличию свёрток) или SRU (https://arxiv.org/abs/1709.02755).
По сути работы авторы посмотрели на оригинальные архитектуры LSTM и GRU и убрали вещи, которые мешали реализовать их обучение через parallel scan.
В GRU убралась зависимость update gate (z) и скрытого состояния (h) от предыдущего значения h. Reset gate ушёл совсем. Затем ушёл tanh при вычислении финального значения h. Теперь нужно O(2*d_h*d_x) параметров вместо O(3*d_h(d_x + d_h)) в оригинальном GRU.
В LSTM также ушла зависимость от предыдущего состяния h в forget и input гейтах, а также в содержимом ячейки памяти c. Из вычисления c также ушёл tanh, и в итоге дропнули output gate и саму ячейку c, осталось только h. minLSTM требует O(3*d_h*d_x) параметров вместо O(4*d_h(d_x + d_h)) в LSTM.
По времени вычисления новые модели minLSTM/minGRU сравнимы с Mamba и, например, на длине последовательности в 512 элементов быстрее оригиналов в LSTM/GRU в 235 и 175 раз. На больших длинах ещё солиднее.
BY gonzo-обзоры ML статей
Share with your friend now:
tgoop.com/gonzo_ML/2938