Neural Collapse in deep classifiers during the Terminal Phase of Training

Neural collapse refers to the observation that the last two layers of neural networks trained for a long time take a very simple, restricted and universal form - namely an equiangular tight frame. The papers referenced in this pill highlight the phenomenon and provide theoretical hints as to why it happens.

Many of us are used to thinking of deep neural networks as entirely black boxes that elude rigorous analysis. Studying their puzzling behavior during training with gradient descent (e.g. deep double descent and “grokking” [Pow22G]) and investigating the data-processing properties of trained networks are so complicated that entire fields of research are dedicated to them.

However, intriguing recent work on a topic called Neural Collapse has demonstrated that the last two layers of trained deep classifiers are not only understandable but take a remarkably simple and restricted form. In fact, under some assumptions and simplifications, their dynamics can be analyzed in closed form as training proceeds.

The papers that we will describe below are quite technical, but their results are very well summarized by this gif and by Figure 1. The phenomenon of neural collapse happens as the training of a neural network on a classification task enters its final stage, when the classification error on the training set has reached zero but the loss (cross-entropy or alternatives like MSE) keeps decreasing. This regime has been dubbed the Terminal Phase of Training (TPT).

During TPT, as the network converges to minimal loss, the activations of the penultimate layer (evaluated on training data) and the weights of the last (classifying) layer converge to a very simple form, characterized by the following properties:

  1. The within-class variance of the activations (small blue dots) tends to zero.
  2. The class means of the activations (large blue dots) are positioned on a (rotated) Equiangular Tight Frame (ETF, shown as large green dots for reference): they form C vectors of equal norm, arranged so that the angle between any two of them is the same and maximal (see the sketch after this list).
  3. The weights of the last layer, being a linear classifier, align with the same ETF as the activations (red dots in the figure).
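
To make property 2 concrete: the standard simplex ETF for C classes is obtained by rescaling the centered identity matrix, M = sqrt(C / (C - 1)) * (I - (1/C) 11^T). Its rows have unit norm and pairwise cosine -1/(C - 1), the most negative value that C vectors can achieve simultaneously. A minimal numpy sketch (the function name is ours) verifying both properties:

```python
import numpy as np

def simplex_etf(C: int) -> np.ndarray:
    """Rows form a simplex ETF for C classes: sqrt(C/(C-1)) * (I - (1/C) 11^T)."""
    return np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)

C = 5
M = simplex_etf(C)
norms = np.linalg.norm(M, axis=1)
cosines = (M @ M.T) / np.outer(norms, norms)
off_diag = cosines[~np.eye(C, dtype=bool)]

assert np.allclose(norms, 1.0)              # equal (unit) norms
assert np.allclose(off_diag, -1 / (C - 1))  # equal, maximally negative pairwise cosines
```

Note that the C rows span only a (C - 1)-dimensional subspace, and any rotation of them is an equally valid ETF: this is why the collapsed configuration is determined only up to rotation and scale.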

There are several theoretical arguments that give insight into why the training results look this way. The first paper describing the phenomenon [Pap20P] shows that, if the activations (blue dots) are already arranged in this way, the classifier weights (red dots) will align themselves accordingly when only the last layer is trained.
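
A quick way to build intuition for this claim: freeze the features at a simplex ETF, one sample per class, and train only the linear classifier. The following toy sketch (our own construction, not the experiment from [Pap20P]) runs gradient descent on the cross-entropy with a small weight decay:

```python
import torch
import torch.nn.functional as F

C = 4
H = torch.eye(C) - torch.ones(C, C) / C    # fixed simplex-ETF features, one per class
labels = torch.arange(C)

W = torch.randn(C, C, requires_grad=True)  # only the classifier is trained
opt = torch.optim.SGD([W], lr=0.5, weight_decay=1e-3)
for _ in range(5_000):
    opt.zero_grad()
    F.cross_entropy(H @ W.T, labels).backward()
    opt.step()

# Each classifier row w_c ends up pointing along its class feature h_c.
print(F.cosine_similarity(W.detach(), H, dim=1))  # all entries close to 1
```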

An even stronger result, proven in [Sou18I], states that if the small blue dots are linearly separable, the red dots will align themselves with their class means (the large blue dots). The follow-up work studying classifiers trained with MSE loss, “Neural Collapse Under MSE Loss: Proximity to and Dynamics on the Central Path” [Han22N], performs a rigorous analysis of the dynamics of the last layers of a NN (in isolation from the rest) under gradient descent, showing the emergence of neural collapse.

However, neural collapse in deep networks is not a mathematically proven statement, as all of the aforementioned sources analyze the last layers in isolation (a technique sometimes called “layer peeling” [Fan21E]). To my knowledge, making a rigorous statement about the convergence point of a full, deep network is rather difficult and has not been done so far.
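
The layer-peeled setting itself is easy to simulate: treat the penultimate-layer features as free parameters and optimize them jointly with the classifier. The sketch below is our own toy version ([Fan21E] constrains the norms of features and weights rather than using the weight decay we add here); running it exhibits all three properties listed above:

```python
import torch
import torch.nn.functional as F

C, n, d = 4, 32, 16                            # classes, samples per class, feature dim
labels = torch.arange(C).repeat_interleave(n)

H = torch.randn(C * n, d, requires_grad=True)  # free "activations", one per sample
W = torch.randn(C, d, requires_grad=True)      # linear classifier
opt = torch.optim.SGD([H, W], lr=0.5, weight_decay=5e-4)

for _ in range(10_000):
    opt.zero_grad()
    F.cross_entropy(H @ W.T, labels).backward()
    opt.step()

means = torch.stack([H[labels == c].mean(0) for c in range(C)]).detach()
spread = torch.stack([H[labels == c].std(0).mean() for c in range(C)]).detach()
print("within-class spread:", spread)          # tends to zero (property 1)
cos_means = (means @ means.T) / torch.outer(means.norm(dim=1), means.norm(dim=1))
print("pairwise cos of means:\n", cos_means)   # off-diagonals approach -1/(C-1) (property 2)
print("cos(mean_c, w_c):", F.cosine_similarity(means, W.detach(), dim=1))  # -> 1 (property 3)
```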

The existence of such a simple structure, essentially governed by a single parameter (the size of the frame), emerging from training neural networks with gradient descent is not only interesting in itself, but can provide insight into transfer learning, calibration, the design of new losses and much more. We will come back to it in future posts.

References
