tgoop.com/ai_tablet/160
Last Update:
TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
Новая Sota (ли ?) от Яндекса в табличных задачах, TabM, — это MLP-архитектура, которая имитирует ансамбль из k (в статье 32) сетей. Она делает несколько предсказаний на один объект, а подмодели обучаются одновременно и разделяют большую часть весов почти как в BatchEnsemble. Это позволило отказаться от attention, ускорить обучение и улучшить метрики за счет ассемблирования. В статье утверждают что метрики лучше чем у бустингов, но кажется метрики стат. значимо не отличаются. Забавно что excel сильнее базового MLP из чего и состоит текущее решение
Протестировал сравнение метрик базового LightGBM и этого решения. Это было, конечно, намного легче, чем с TabR (прошлая Sota от Яндекса), код которого в виде библиотеки не выкладывали, но всё равно из коробки модель не обучалась. Пришлось взять параметры из статьи; на чуть больших датасетах это всё падает по памяти, ошибки cuda
Как итог, LightGBM оказался существенно лучше на 2-м датасете, но на 1-м — почти паритет. Но какой же TabM медленный, на CPU время обучения отличается х1000раз и это на 100 эпохах, в статье предлают обучать еще больше! И всё же результат достойный, но статью имеет смысл перепроверить с точки зрения метрик
Average LGB Test AUC: 0.7659
Average TabM Test AUC: 0.7421
Average LGB Time: 0.23s
Average TabM Time (CPU amd 7700): 234.55s
Average TabM Time (gpu T4): 15.68s