Linear attention is (maybe) all you need (to understand Transformer optimization)

Abstract

Transformer training is notoriously difficult, requiring careful design of optimizers and the use of various heuristics. We make progress towards understanding the subtleties of training transformers by carefully studying a simple yet canonical linearized shallow transformer model. Specifically, we train linear transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023) and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models can reproduce several prominent aspects of transformer training dynamics. Consequently, our results suggest that a simple linearized transformer model could be a valuable, realistic abstraction for understanding transformer optimization.
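The regression setting mentioned above (linear transformers on in-context regression, following von Oswald et al. and Ahn et al.) can be illustrated with a minimal sketch. This is not the paper's code. It assumes a single residual linear self-attention layer (no softmax), a prompt whose columns are the (x_i, y_i) pairs followed by (x_query, 0), a mask that keeps the query token from acting as a key/value, and a sign convention chosen here for simplicity. The helper names (build_prompt, linear_self_attention, gd_weights, gd_one_step_prediction) are hypothetical. The script checks the observation from von Oswald et al. that a particular weight setting makes one such layer output the prediction of one gradient-descent step on the in-context least-squares loss.

```python
import numpy as np


def build_prompt(X, y, x_query):
    """Stack n in-context examples (x_i, y_i) and the query (x_query, 0)
    as the columns of a (d+1) x (n+1) matrix Z."""
    n, d = X.shape
    Z = np.zeros((d + 1, n + 1))
    Z[:d, :n] = X.T          # features of the in-context examples
    Z[d, :n] = y             # labels of the in-context examples
    Z[:d, n] = x_query       # query features; its label slot stays 0
    return Z


def linear_self_attention(Z, P, Q):
    """One residual linear self-attention layer (no softmax):
    Z + (1/n) * P Z M (Z^T Q Z), where M zeroes the query column
    so the query token is not used as a key/value."""
    n = Z.shape[1] - 1
    M = np.eye(n + 1)
    M[n, n] = 0.0
    return Z + (P @ Z @ M @ (Z.T @ Q @ Z)) / n


def gd_weights(d, eta):
    """Hypothetical weight choice under which the query's label slot
    receives the prediction of one GD step from w = 0 with step size eta."""
    P = np.zeros((d + 1, d + 1))
    P[d, d] = 1.0                # value/projection reads only the label row
    Q = np.zeros((d + 1, d + 1))
    Q[:d, :d] = eta * np.eye(d)  # key-query product compares features only
    return P, Q


def gd_one_step_prediction(X, y, x_query, eta):
    """Prediction after one GD step on (1/2n) * sum_i (w^T x_i - y_i)^2 from w = 0."""
    n = X.shape[0]
    w1 = eta * (X.T @ y) / n     # w1 = 0 - eta * grad L(0) = (eta/n) * sum_i y_i x_i
    return w1 @ x_query


rng = np.random.default_rng(0)
d, n, eta = 5, 20, 0.1
w_star = rng.standard_normal(d)
X = rng.standard_normal((n, d))
y = X @ w_star
x_query = rng.standard_normal(d)

Z_out = linear_self_attention(build_prompt(X, y, x_query), *gd_weights(d, eta))
lsa_pred = Z_out[d, n]           # prediction is read from the query's label slot
gd_pred = gd_one_step_prediction(X, y, x_query, eta)
print(np.isclose(lsa_pred, gd_pred))  # True
```

In a training experiment one would instead treat P and Q as learned parameters and fit them over many random regression tasks; the fixed weights here only show that the linear-attention architecture can express a gradient-descent step.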

Publication
ICLR 2024, short version at NeurIPS 2023 Workshop on Mathematics of Modern Machine Learning (Oral)
Minhak Song
Undergraduate Student

I am an undergraduate student at KAIST. I am interested in building the theoretical foundations of machine learning, deep learning, and reinforcement learning through the lens of optimization theory and statistics.