Neural Collapse in deep classifiers during the Terminal Phase of Training

Many of us are used to thinking of deep neural networks as complete black boxes that elude rigorous analysis. Their puzzling behavior during training with gradient descent (e.g. deep double descent and “grokking” [Pow22G]) and the data-processing properties of trained networks are so complicated to study that entire fields of research have emerged to address them.

However, intriguing recent work on a phenomenon called Neural Collapse has shown that the last two layers of trained deep classifiers are not only understandable but take a remarkably simple and restricted form. In fact, under some assumptions and simplifications, their dynamics can be analyzed in closed form as training proceeds.

The papers we describe below are quite technical, but their results are very well summarized by this gif and by the figure. Neural collapse occurs once the training of a neural network on a classification task enters its final stage: the classification error on the training set has reached zero, but the training loss (cross-entropy, or alternatives such as MSE) keeps decreasing. This regime has been dubbed the Terminal Phase of Training (TPT).
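To make the definition of TPT concrete, here is a minimal sketch under our own assumptions. It trains a bias-free softmax-linear classifier with full-batch gradient descent on a hypothetical toy dataset of three well-separated Gaussian blobs and logs the training error and cross-entropy loss at every step. The data, step size and function name are illustrative choices that do not come from the papers. Once the error hits zero, the loss keeps shrinking without ever reaching zero, which is the regime described above.

```python
import numpy as np

def train_softmax_classifier(x, y, num_classes, lr=0.05, steps=2000):
    """Full-batch GD on the mean cross-entropy of a bias-free linear classifier.

    Returns the training error and loss, both measured before each update.
    """
    n, d = x.shape
    w = np.zeros((num_classes, d))
    onehot = np.eye(num_classes)[y]
    errors, losses = [], []
    for _ in range(steps):
        logits = x @ w.T
        # Numerically stable log-sum-exp for the cross-entropy.
        m = logits.max(axis=1, keepdims=True)
        lse = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
        losses.append(np.mean(lse - logits[np.arange(n), y]))
        errors.append(np.mean(logits.argmax(axis=1) != y))
        probs = np.exp(logits - lse[:, None])          # softmax probabilities
        w -= lr * (probs - onehot).T @ x / n           # gradient of the mean CE
    return np.array(errors), np.array(losses)

# Hypothetical toy data: three Gaussian blobs centered at radius 4 in 2D.
rng = np.random.default_rng(0)
y = np.repeat(np.arange(3), 50)
angles = 2 * np.pi * y / 3
x = 4.0 * np.stack([np.cos(angles), np.sin(angles)], axis=1) + 0.5 * rng.standard_normal((150, 2))

errors, losses = train_softmax_classifier(x, y, num_classes=3)
tpt_start = int(np.argmax(errors == 0))   # first step with zero training error
print("TPT starts at step", tpt_start)
print("loss at TPT start:", losses[tpt_start], "final loss:", losses[-1])
```

The step size is kept small relative to the curvature of the loss, so every step decreases the loss. This makes the "error at zero, loss still falling" pattern easy to see.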

During TPT, as the network converges towards minimal loss, the activations of the penultimate layer (evaluated on the training data) and the weights of the last (classifying) layer converge to a very simple form, characterized by the following properties:

  1. The within-class variance of the activations (small blue dots) tends to zero
  2. The means of the activations (large blue dots) are positioned on a (rotated) Equiangular Tight Frame (ETF, shown as large green dots for reference): they form C vectors, one per class, all of the same norm and arranged so that the angle between any two of them is the same and as large as possible (pairwise cosine -1/(C-1))
  3. The weights of the last, linear classifying layer (red dots) align with the same ETF as the activation means
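Properties 1 and 2 can be checked numerically. Below is a minimal sketch, assuming NumPy, that builds a randomly rotated simplex ETF and computes simple collapse diagnostics on a set of features. Following [Pap20P], the class means are centered at the global mean before their norms and angles are measured. For within-class variability we use a simplified ratio tr(Σ_W)/tr(Σ_B) instead of the tr(Σ_W Σ_B^†)/C statistic of [Pap20P]. The function names and data are hypothetical.

```python
import numpy as np

def simplex_etf(num_classes, dim, seed=0):
    """Rows form a randomly rotated simplex ETF: unit norm, pairwise cosine -1/(C-1)."""
    if dim < num_classes:
        raise ValueError("this sketch assumes dim >= num_classes")
    c = num_classes
    # Centered one-hot vectors, rescaled to unit norm: the standard simplex ETF in R^c.
    base = np.sqrt(c / (c - 1)) * (np.eye(c) - np.ones((c, c)) / c)
    # Q (dim x c) has orthonormal columns, so base @ Q.T preserves norms and angles.
    q, _ = np.linalg.qr(np.random.default_rng(seed).standard_normal((dim, c)))
    return base @ q.T

def collapse_metrics(features, labels):
    """Simple NC1/NC2 diagnostics for features of shape (n, d) and integer labels."""
    classes = np.unique(labels)
    means = np.stack([features[labels == k].mean(axis=0) for k in classes])
    centered = means - features.mean(axis=0)           # class means minus global mean
    within = features - means[np.searchsorted(classes, labels)]
    # NC1 proxy: tr(Sigma_W) / tr(Sigma_B), zero when every sample sits on its class mean.
    nc1 = (np.sum(within**2) / len(features)) / (np.sum(centered**2) / len(classes))
    # NC2: equal norms and equal pairwise cosines of -1/(C-1).
    norms = np.linalg.norm(centered, axis=1)
    cos = (centered @ centered.T) / np.outer(norms, norms)
    off_diag = cos[~np.eye(len(classes), dtype=bool)]
    return {
        "nc1": nc1,
        "norm_spread": norms.std() / norms.mean(),
        "max_cos_error": np.abs(off_diag + 1.0 / (len(classes) - 1)).max(),
    }

etf = simplex_etf(num_classes=4, dim=8)
labels = np.repeat(np.arange(4), 25)
collapsed = 5.0 * etf[labels]        # every sample sits exactly on its class vertex
noisy = collapsed + np.random.default_rng(1).normal(scale=1.0, size=collapsed.shape)
print(collapse_metrics(collapsed, labels))
print(collapse_metrics(noisy, labels))
```

For the collapsed features all three diagnostics are zero up to rounding. The noisy features keep an ETF-like arrangement of their means but have a clearly nonzero within-class ratio.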

There are several theoretical arguments that give insight into why training ends up in this configuration. The first paper describing the phenomenon [Pap20P] shows that, if the blue dots are already arranged in this way and only the last layer is trained, the red dots align themselves accordingly.
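This last-layer statement is easy to try out. The sketch below fixes the features on a scaled simplex ETF and trains only the classifier weights. It assumes cross-entropy with no bias and no weight decay, and uses full-batch gradient descent from a small random initialization; all of these are our own choices. It then measures the cosine between each weight row and its class mean.

```python
import numpy as np

def simplex_etf(num_classes, dim, seed=0):
    c = num_classes
    base = np.sqrt(c / (c - 1)) * (np.eye(c) - np.ones((c, c)) / c)
    q, _ = np.linalg.qr(np.random.default_rng(seed).standard_normal((dim, c)))
    return base @ q.T                      # rows: unit norm, pairwise cosine -1/(c-1)

def train_last_layer(features, labels, num_classes, lr=0.1, steps=5000, seed=0):
    """Full-batch GD on the mean cross-entropy over the last-layer weights only."""
    n, d = features.shape
    w = 0.01 * np.random.default_rng(seed).standard_normal((num_classes, d))
    onehot = np.eye(num_classes)[labels]
    for _ in range(steps):
        logits = features @ w.T
        logits -= logits.max(axis=1, keepdims=True)   # stabilize the softmax
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        w -= lr * (probs - onehot).T @ features / n
    return w

num_classes, dim = 4, 8
labels = np.repeat(np.arange(num_classes), 10)
means = 3.0 * simplex_etf(num_classes, dim)          # fixed, already-collapsed features
features = means[labels]
w = train_last_layer(features, labels, num_classes)

# Cosine between each classifier row and its class mean.
cosines = np.sum(w * means, axis=1) / (np.linalg.norm(w, axis=1) * np.linalg.norm(means, axis=1))
print(cosines)
```

The cosines approach 1 but never reach it exactly. The gradient always lies in the span of the features, so the part of the random initialization orthogonal to that span is never updated. Meanwhile the aligned part of the weights keeps growing slowly, because the loss never reaches zero.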

There is an even more general statement, proven in [Sou18I]: if the small blue dots are linearly separable, gradient descent on the last layer converges in direction to the max-margin classifier. Once the activations have collapsed onto an ETF, this means the red dots align with the class means (the large blue dots). A follow-up to [Pap20P] that studies classifiers trained with MSE loss, “Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path” [Han22N], performed a rigorous analysis of the gradient-descent dynamics of the last layers of a network (in isolation from the rest of it) and showed how neural collapse emerges.

However, neural collapse in deep networks is not a mathematically proven statement, since all the aforementioned sources analyze the last layers in isolation (a concept sometimes called “layer peeling” [Fan21E]). To my knowledge, making a rigorous statement about the convergence point of a full, deep network is rather difficult and has not been done so far.
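Layer peeling can be illustrated with a toy model in which the penultimate-layer features H are free variables, optimized jointly with the last-layer weights W. The sketch below follows the spirit of the layer-peeled model of [Fan21E] but replaces its norm constraints with weight-decay penalties on W and H, a variant often called the unconstrained features model. Dimensions, penalties and step size are illustrative assumptions. It only checks that within-class variability collapses, using the simplified ratio tr(Σ_W)/tr(Σ_B) of within- to between-class scatter.

```python
import numpy as np

def within_between_ratio(h, labels):
    """tr(Sigma_W) / tr(Sigma_B): a simplified within-class variability measure."""
    classes = np.unique(labels)
    means = np.stack([h[labels == k].mean(axis=0) for k in classes])
    within = h - means[np.searchsorted(classes, labels)]
    between = means - h.mean(axis=0)
    return (np.sum(within**2) / len(h)) / (np.sum(between**2) / len(classes))

def unconstrained_features_gd(num_classes=3, per_class=4, dim=6, lam_w=5e-3,
                              lam_h=5e-3, lr=0.2, steps=5000, seed=0):
    """GD on mean CE(W h_i, y_i) + lam_w/2 ||W||^2 + lam_h/2 ||H||^2 with free features H."""
    rng = np.random.default_rng(seed)
    labels = np.repeat(np.arange(num_classes), per_class)
    n = len(labels)
    onehot = np.eye(num_classes)[labels]
    h = rng.standard_normal((n, dim))                  # free features, one row per sample
    w = 0.1 * rng.standard_normal((num_classes, dim))  # last-layer weights
    h_init = h.copy()
    for _ in range(steps):
        logits = h @ w.T
        logits -= logits.max(axis=1, keepdims=True)
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        residual = (probs - onehot) / n                # gradient of the mean CE w.r.t. logits
        grad_w = residual.T @ h + lam_w * w
        grad_h = residual @ w + lam_h * h
        w -= lr * grad_w                               # both gradients were computed first,
        h -= lr * grad_h                               # so this is a simultaneous GD step
    return h_init, h, w, labels

h_init, h, w, labels = unconstrained_features_gd()
print("within/between ratio, initial:", within_between_ratio(h_init, labels))
print("within/between ratio, final:  ", within_between_ratio(h, labels))
```

For fixed W, each feature vector minimizes a strictly convex objective that is identical for all samples of its class. Gradient descent therefore contracts the differences between same-class features, and the ratio should drop by orders of magnitude. Whether the class means also settle on an ETF can be checked by passing the features to a collapse diagnostic such as the pairwise-cosine test against -1/(C-1).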

The existence of such a simple structure, essentially governed by a single parameter (the size of the frame), emerging from training neural networks with gradient descent is not only interesting in itself but can also provide insight into transfer learning, calibration, the design of new losses and much more. We will come back to it in future posts.

References

  • Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets, Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin, Vedant Misra. arXiv:2201.02177 [cs] (2022)
  • Prevalence of Neural Collapse during the terminal phase of deep learning training, Vardan Papyan, X. Y. Han, David L. Donoho. Proceedings of the National Academy of Sciences (2020)
  • The Implicit Bias of Gradient Descent on Separable Data, Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, Nathan Srebro. Journal of Machine Learning Research (2018)
  • Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path, X. Y. Han, Vardan Papyan, David L. Donoho. (2022)
  • Exploring Deep Neural Networks via Layer-Peeled Model: Minority Collapse in Imbalanced Training, Cong Fang, Hangfeng He, Qi Long, Weijie J. Su. arXiv:2101.12699 [cs, math, stat] (2021)
