tgoop.com/tensorflowblog/458
Last Update:
SALSA (Stable Armijo Line Search Adaptation) β ΠΌΠ΅ΡΠΎΠ΄, ΡΠ°Π·ΡΠ°Π±ΠΎΡΠ°Π½Π½ΡΠΉ Π΄Π»Ρ ΠΎΠΏΡΠΈΠΌΠΈΠ·Π°ΡΠΈΠΈ Learning Rate (LR) Π²ΠΎ Π²ΡΠ΅ΠΌΡ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ.
ΠΡΠ½ΠΎΠ²Π½Π°Ρ ΠΊΠΎΠ½ΡΠ΅ΠΏΡΠΈΡ ΠΌΠ΅ΡΠΎΠ΄Π° ΠΏΠΎΡΡΡΠΎΠ΅Π½Π° Π²ΠΎΠΊΡΡΠ³ Π²ΡΠΏΠΎΠ»Π½Π΅Π½ΠΈΡ Π»ΠΈΠ½Π΅ΠΉΠ½ΠΎΠ³ΠΎ ΠΏΠΎΠΈΡΠΊΠ° Π΄Π»Ρ ΠΎΠΏΡΠ΅Π΄Π΅Π»Π΅Π½ΠΈΡ Π½Π°ΠΈΠ»ΡΡΡΠ΅Π³ΠΎ Π²ΠΎΠ·ΠΌΠΎΠΆΠ½ΠΎΠ³ΠΎ LR Π΄Π»Ρ ΠΊΠ°ΠΆΠ΄ΠΎΠ³ΠΎ ΡΠ°Π³Π° ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ, ΡΡΠΎ Π΄Π°Π΅Ρ Π±ΡΡΡΡΡΡ ΡΡ
ΠΎΠ΄ΠΈΠΌΠΎΡΡΡ ΠΈ ΡΠ»ΡΡΡΠ΅Π½Π½ΠΎΠ΅ ΠΎΠ±ΠΎΠ±ΡΠ΅Π½ΠΈΠ΅.
Π§ΡΠΎΠ±Ρ ΡΠΌΠ΅Π½ΡΡΠΈΡΡ Π²ΡΡΠΈΡΠ»ΠΈΡΠ΅Π»ΡΠ½ΡΡ Π½Π°Π³ΡΡΠ·ΠΊΡ, Salsa ΠΏΡΠ΅Π΄Π»Π°Π³Π°Π΅Ρ ΠΏΠΎΡΠ°Π³ΠΎΠ²ΡΠΉ ΠΌΠΈΠ½ΠΈΠ°ΡΡΡΠ½ΡΠΉ Π»ΠΈΠ½Π΅ΠΉΠ½ΡΠΉ ΠΏΠΎΠΈΡΠΊ. Π Π½Π΅ΠΌ LR ΠΏΠΎΡΡΠ΅ΠΏΠ΅Π½Π½ΠΎ ΡΠ²Π΅Π»ΠΈΡΠΈΠ²Π°Π΅ΡΡΡ Ρ ΠΊΠ°ΠΆΠ΄ΡΠΌ ΡΠ°Π³ΠΎΠΌ, Π° ΠΊΡΠΈΡΠ΅ΡΠΈΠΉ Π»ΠΈΠ½Π΅ΠΉΠ½ΠΎΠ³ΠΎ ΠΏΠΎΠΈΡΠΊΠ° ΠΏΠΎΡΡΠΎΡΠ½Π½ΠΎ ΠΏΠ΅ΡΠ΅ΠΎΡΠ΅Π½ΠΈΠ²Π°Π΅ΡΡΡ.
ΠΠΎΠΏΠΎΠ»Π½ΠΈΡΠ΅Π»ΡΠ½ΠΎ, Salsa Π²ΠΊΠ»ΡΡΠ°Π΅Ρ ΡΠΊΡΠΏΠΎΠ½Π΅Π½ΡΠΈΠ°Π»ΡΠ½ΠΎΠ΅ ΡΠ³Π»Π°ΠΆΠΈΠ²Π°Π½ΠΈΠ΅ Π² ΠΏΡΠΎΡΠ΅ΡΡ Π»ΠΈΠ½Π΅ΠΉΠ½ΠΎΠ³ΠΎ ΠΏΠΎΠΈΡΠΊΠ° ΠΈ ΡΡΡΠ°Π½Π°Π²Π»ΠΈΠ²Π°Π΅Ρ Π΄Π²Π° ΡΠΊΡΠΏΠΎΠ½Π΅Π½ΡΠΈΠ°Π»ΡΠ½ΡΡ
ΡΠΊΠΎΠ»ΡΠ·ΡΡΠΈΡ
ΡΡΠ΅Π΄Π½ΠΈΡ
Π΄Π»Ρ ΡΠΊΠΎΡΠΎΡΡΠΈ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ. ΠΡΠΎ ΠΏΠΎΠΌΠΎΠ³Π°Π΅Ρ ΡΡΠ°Π±ΠΈΠ»ΠΈΠ·ΠΈΡΠΎΠ²Π°ΡΡ ΠΎΠΏΡΠΈΠΌΠΈΠ·Π°ΡΠΈΡ ΠΈ ΡΠΌΠ΅Π½ΡΡΠΈΡΡ Π½Π΅ΡΡΠ°Π±ΠΈΠ»ΡΠ½ΠΎΡΡΡ ΠΎΡ ΠΌΠΈΠ½ΠΈ-ΠΏΠ°ΠΊΠ΅ΡΠΈΡΠΎΠ²Π°Π½ΠΈΡ.
ΠΠΊΡΠΏΠ΅ΡΠΈΠΌΠ΅Π½ΡΠ°Π»ΡΠ½ΡΠ΅ ΡΠ΅Π·ΡΠ»ΡΡΠ°ΡΡ ΠΏΠΎΠΊΠ°Π·ΡΠ²Π°ΡΡ, ΡΡΠΎ Salsa ΠΏΡΠ΅Π²ΠΎΡΡ
ΠΎΠ΄ΠΈΡ Π΄ΡΡΠ³ΠΈΠ΅ ΠΌΠ΅ΡΠΎΠ΄Ρ ΠΎΠΏΡΠΈΠΌΠΈΠ·Π°ΡΠΈΠΈ: 50% ΡΠΎΠΊΡΠ°ΡΠ΅Π½ΠΈΠ΅ final loss ΠΈ 1,25 average rank Π² ΡΠ·ΡΠΊΠΎΠ²ΡΡ
ΠΈ Π³ΡΠ°ΡΠΈΡΠ΅ΡΠΊΠΈΡ
Π·Π°Π΄Π°ΡΠ°Ρ
.
ΠΡΡΠΈΡΠ»ΠΈΡΠ΅Π»ΡΠ½ΡΠ΅ ΠΈΠ·Π΄Π΅ΡΠΆΠΊΠΈ Salsa Π²ΡΠ΅Π³ΠΎ Π½Π° 3% Π²ΡΡΠ΅, ΡΠ΅ΠΌ Ρ Π±Π°Π·ΠΎΠ²ΠΎΠ³ΠΎ LR ΠΌΠ΅ΡΠΎΠ΄Π°, ΡΡΠΎ ΠΌΠΎΠΆΠ½ΠΎ Π²ΠΎΡΠΏΡΠΈΠ½ΠΈΠΌΠ°ΡΡ ΠΊΠ°ΠΊ Π½Π΅Π·Π½Π°ΡΠΈΡΠ΅Π»ΡΠ½ΡΠΌ ΡΠ²Π΅Π»ΠΈΡΠ΅Π½ΠΈΠ΅ΠΌ, ΡΡΠΈΡΡΠ²Π°Ρ ΠΏΠΎΠΊΠ°Π·Π°ΡΠ΅Π»ΠΈ ΠΏΡΠΎΠΈΠ·Π²ΠΎΠ΄ΠΈΡΠ΅Π»ΡΠ½ΠΎΡΡΠΈ. Salsa Π΄ΠΎΡΡΠ°ΡΠΎΡΠ½ΠΎ ΡΠ½ΠΈΠ²Π΅ΡΡΠ°Π»Π΅Π½, ΡΡΠΎΠ±Ρ ΠΈΡΠΏΠΎΠ»ΡΠ·ΠΎΠ²Π°ΡΡΡΡ Ρ ΡΠ°Π·Π»ΠΈΡΠ½ΡΠΌΠΈ ΠΎΠΏΡΠΈΠΌΠΈΠ·Π°ΡΠΎΡΠ°ΠΌΠΈ, ΠΈ ΠΎΡΠΎΠ±Π΅Π½Π½ΠΎ ΡΡΡΠ΅ΠΊΡΠΈΠ²Π΅Π½ ΠΏΡΠΈ ΠΎΠ±ΡΡΠ΅Π½ΠΈΠΈ ΡΠΎΠ²ΡΠ΅ΠΌΠ΅Π½Π½ΡΡ
Π°ΡΡ
ΠΈΡΠ΅ΠΊΡΡΡ, ΠΊΠΎΡΠΎΡΡΠ΅ ΡΡΠ²ΡΡΠ²ΠΈΡΠ΅Π»ΡΠ½Ρ ΠΊ ΡΠΊΠΎΡΠΎΡΡΠΈ ΠΎΠ±ΡΡΠ΅Π½ΠΈΡ.
# Clone repository:
git clone https://github.com/TheMody/No-learning-rates-needed-Introducing-SALSA-Stable-Armijo-Line-Search-Adaptation.git
# Create & activate env:
conda env create -f environment.yml
conda activate sls3
# Install dependencies:
pip install pytorch numpy transformers datasets tensorflow-datasets wandb
# NOTE: custom optimizer is in \salsa\SaLSA.py,comparison version are in \salsa\adam_sls.py:
from salsa.SaLSA import SaLSA
self.optimizer = SaLSA(model.parameters())
# NOTE: typical pytorch forward pass needs to be changed to:
def closure(backwards = False):
y_pred = model(x)
loss = criterion(y_pred, y)
if backwards: loss.backward()
return loss
optimizer.zero_grad()
loss = optimizer.step(closure = closure)
@ai_machinelearning_big_data
#AI #LLM #ML #Train #SALSA