Longer than Chinchilla Language Models
In large language model pretraining, every single training run takes a massive compute budget.
The Chinchilla-optimal bounds were proposed in the paper "An empirical analysis of compute-optimal large language model training". A very common misunderstanding about the Chinchilla scaling law is that it imposes an upper bound on the number of tokens one should train on for a fixed parameter count. What it actually describes is the optimal tradeoff between token count and model size for a fixed compute budget. In practice, it can give a good reference number of tokens, but the general rule of thumb is still to train on as many tokens as possible before the training or eval loss starts to diverge.
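To make the "fixed compute, not fixed parameters" point concrete, here is a minimal sketch that splits a FLOP budget between model size and tokens. It assumes two approximations that are not stated in the text above: training compute C ≈ 6·N·D for a dense transformer, and the commonly quoted Chinchilla rule of thumb of roughly 20 tokens per parameter. The function name `chinchilla_optimal` is hypothetical.

```python
import math

# Rule-of-thumb ratio commonly read off the Chinchilla results; an approximation, not an exact law.
TOKENS_PER_PARAM = 20


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = TOKENS_PER_PARAM):
    """Split a FLOP budget C ~= 6*N*D into model size N and token count D, assuming D = k*N."""
    # C = 6 * N * (k * N)  =>  N = sqrt(C / (6k)), then D = k * N
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


# Budget of a 3B model trained on 1T tokens: 6 * 3e9 * 1e12 = 1.8e22 FLOPs.
n, d = chinchilla_optimal(6 * 3e9 * 1e12)
print(f"{n / 1e9:.1f}B params, {d / 1e9:.0f}B tokens")  # 12.2B params, 245B tokens
```

Under this heuristic, the compute that trains a 3B model on 1T tokens would be "optimally" spent on a ~12B model and ~245B tokens. Nothing in the calculation caps the tokens for a given model size; training the smaller model longer simply spends extra training compute in exchange for a smaller model.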
For billion-parameter models, training beyond the Chinchilla bound is usually slow and costly. For example, training a 3B model on 1T tokens with 512 A100 GPUs at 30% MFU might take 4-5 days. Such runs are rare, and many things can go wrong when training on so many tokens. To add a few discrete data points to the sparse experience people have, here are some interesting failures I have encountered beyond the Chinchilla bound.
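The 4-5 day figure can be sanity-checked with back-of-the-envelope arithmetic. This sketch assumes the standard C ≈ 6·N·D approximation (ignoring attention FLOPs) and the A100's dense BF16 peak of 312 TFLOP/s; `training_days` is a hypothetical helper, not part of any library.

```python
A100_BF16_PEAK_FLOPS = 312e12  # dense BF16 peak per A100 GPU (no structured sparsity)


def training_days(n_params, n_tokens, n_gpus, peak_flops_per_gpu, mfu):
    """Rough wall-clock estimate for a dense transformer using C ~= 6 * N * D."""
    total_flops = 6 * n_params * n_tokens              # forward + backward, ignoring attention FLOPs
    cluster_flops = n_gpus * peak_flops_per_gpu * mfu  # sustained FLOP/s at the given MFU
    return total_flops / cluster_flops / 86_400        # seconds -> days


print(f"{training_days(3e9, 1e12, 512, A100_BF16_PEAK_FLOPS, 0.30):.2f} days")  # 4.35 days
```

Real runs add checkpointing, restarts, and evaluation time on top of this, which pushes the estimate toward the upper end of the range.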
760M with weight decay at 10% of lr
The Chinchilla-optimal token count for a 760M model is somewhere around 10-15B tokens. This is a very small amount for training an LM from scratch, so in practice we almost have to go beyond Chinchilla. The model specs are as follows:
parameter count | context length | batch size | learning rate |
---|---|---|---|
760M | 8192 | 8M tokens | 0.00025 |
The training appeared to converge after 30k steps but unfortunately diverged at 50k (~400B tokens). There are many possible reasons. We have found that weight decay is a subtle parameter and usually contributes to divergence after a large number of tokens.
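The ~400B figure follows from steps × tokens per batch, and comparing it with the Chinchilla estimate shows how far past the bound this run went. This sketch assumes the ~20 tokens-per-parameter rule of thumb and reads "8M tokens" as 8e6; if the batch is 8 × 2^20 tokens, the total is closer to 419B. The helper name is hypothetical.

```python
TOKENS_PER_PARAM = 20  # Chinchilla rule of thumb, an approximation


def tokens_seen(steps: int, batch_tokens: float) -> float:
    """Tokens consumed after `steps` optimizer steps with a fixed token batch size."""
    return steps * batch_tokens


n_params = 760e6
batch_tokens = 8e6  # "8M tokens" read as 8e6; 8 * 2**20 tokens would give ~419B instead
seen = tokens_seen(50_000, batch_tokens)
optimal = TOKENS_PER_PARAM * n_params
print(f"{seen / 1e9:.0f}B tokens, {seen / optimal:.1f}x the ~{optimal / 1e9:.1f}B Chinchilla estimate")
# 400B tokens, 26.3x the ~15.2B Chinchilla estimate
```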
But in this particular case, I would throw another theory out there:
- My learning rate was probably not decaying at the right rate.
We can see that at the divergence step (50k), the cosine lr decay is about half-way through, at roughly 1.7e-4. Some other runs landed smoothly with faster lr decay, reaching below 1e-4. But this is just a theory: some runs with lower weight decay also landed more smoothly. I wish I could run a sweep to confirm which factor is responsible, but getting to 50k steps (~400B tokens) is very costly.
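The "faster lr decay" knob is easiest to see with a cosine schedule written out. This is a generic linear-warmup plus cosine-decay sketch; the run's actual warmup length, total steps, and final lr are not given in the text, so the values in the usage lines are hypothetical.

```python
import math


def cosine_lr(step: int, max_lr: float, warmup_steps: int, total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to max_lr, then cosine decay from max_lr to min_lr at total_steps."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical schedules: same peak lr, different planned lengths, evaluated at step 50k.
for total in (70_000, 100_000):
    print(total, f"{cosine_lr(50_000, 2.5e-4, 2_000, total):.2e}")
```

With cosine decay, the lr at a given step depends on the planned schedule length: a shorter plan decays faster and reaches a lower lr earlier, and half-way through the decay phase the lr sits at (max_lr + min_lr) / 2.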
3B with decoupled-lionw
Two of the runs in the chart are normal, but let us pay attention to the red curve.
This 3B model diverged at a very early stage (~60k steps). Model specs:
parameter count | context length | batch size | learning rate |
---|---|---|---|
3B | 8192 | 8M tokens | 0.000256 |
The optimizer is LionW with decoupled weight decay, meaning the decay is a fixed number rather than a quantity proportional to the learning rate. In this case, I used 0.0000256, 10% of the max learning rate. This was originally planned as an ambitious run (a trillion tokens) with a very slow warmup. At the divergence point, it had not even finished the warm-up! As a result, I am inclined to attribute the divergence most likely to the decoupled weight decay, which was too high during the warm-up period.
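A minimal NumPy sketch of one Lion step shows why a fixed decay is so strong early in warm-up. It follows the Lion update (sign of an interpolated momentum, β1 = 0.9, β2 = 0.99 defaults) and implements "decoupled" in the sense described above: the weights shrink by a fixed `wd` each step instead of `lr * wd`. This is an illustration, not the code of any particular library; implementations with similar names may scale the decay differently. `lion_step` and `decay_to_update_ratio` are hypothetical names.

```python
import numpy as np


def lion_step(theta, m, grad, lr, wd, beta1=0.9, beta2=0.99, decoupled=True):
    """One Lion update on NumPy arrays.

    decoupled=True : weights shrink by a fixed fraction `wd` every step,
                     independent of the current learning rate.
    decoupled=False: weights shrink by `lr * wd`, so decay follows the lr schedule.
    """
    c = beta1 * m + (1.0 - beta1) * grad                # interpolate momentum and gradient
    decay = wd if decoupled else lr * wd                 # per-step shrink fraction
    theta_new = theta * (1.0 - decay) - lr * np.sign(c)  # sign update has magnitude lr per coordinate
    m_new = beta2 * m + (1.0 - beta2) * grad             # momentum tracking uses beta2
    return theta_new, m_new


def decay_to_update_ratio(lr, wd, theta_abs=1.0, decoupled=True):
    """Per-coordinate size of the decay term relative to the sign update (which has size lr)."""
    decay = wd if decoupled else lr * wd
    return decay * theta_abs / lr


peak_lr, wd = 2.56e-4, 2.56e-5  # values from the 3B run above
for frac in (0.01, 0.1, 1.0):   # hypothetical points along a linear warmup
    lr = peak_lr * frac
    print(f"lr={lr:.2e}  decay/update={decay_to_update_ratio(lr, wd):.1f}")  # ratio = 0.1 / frac
```

At peak lr the fixed decay is 10% of the update size, but 1% into a linear warm-up it is roughly 10x larger than the update, so the shrinking pull dominates exactly while the learning rate is still small.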
Weight decay can be viewed as adding an extra term to the loss function, the squared $L2$-norm of the parameters: $$ l' = l + c\,\|\theta\|^2 $$ where $l$ is the original loss, $l'$ the $L2$-regularized loss, and $c$ the regularization strength.
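For plain gradient descent with step size $\eta$, the gradient of the penalty term is $2c\,\theta$, so one step on the regularized loss becomes

$$ \theta \leftarrow \theta - \eta\left(\nabla_\theta l + 2c\,\theta\right) = (1 - 2\eta c)\,\theta - \eta\,\nabla_\theta l $$

The shrink factor $2\eta c$ scales with the learning rate. A decoupled decay replaces it with a fixed $\lambda$, i.e. $\theta \leftarrow (1 - \lambda)\,\theta - \eta\,u$, where $u$ is the optimizer's update direction (the sign of the interpolated momentum for Lion), so the decay stays at full strength even while $\eta$ is still small. For adaptive optimizers such as Adam or Lion, the $L2$ penalty and a direct weight-shrink step are no longer equivalent.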
Let us take a look at what happened to the $L2$-norm of the weights. The optimization goal balances the language model loss against the size of the model weights. Despite the general hope that gradient descent would decrease both (provable for convex losses), reality showed us that both the language model loss and the weight norm went up.
My overall theory is that the poorly calibrated learning rate trapped the training inside a sharp ditch in the loss landscape. The loss climbs as the optimizer bumps against the walls, and the cliff gets sharper (higher and higher gradient norms). The chart of the gradient $L2$-norm also supports this: gradient clipping was set at 1.0, yet despite the clipping it is pretty obvious that the gradient norm wanted to curve up, showing a steeper and steeper landscape.
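Global-norm clipping caps each applied update without hiding the underlying trend, which helps explain how the gradient norm can keep "wanting to curve up" at a clip value of 1.0, assuming the logged value is the pre-clip norm. This NumPy sketch mirrors the common scheme (PyTorch's `torch.nn.utils.clip_grad_norm_` also returns the total norm before clipping); `clip_by_global_norm` is a hypothetical name.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, pre_clip_norm). The pre-clip norm is the value worth logging:
    it can keep growing even when every applied update is capped at max_norm.
    """
    # Treat all gradients as one long vector and take its L2 norm.
    total_norm = float(np.sqrt(sum(np.sum(g.astype(np.float64) ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [g * scale for g in grads], total_norm


grads = [np.array([3.0, 0.0]), np.array([4.0])]
clipped, norm = clip_by_global_norm(grads)
print(norm)  # 5.0
```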
13B, mostly normal but…
The last one is not about a divergence, but a sudden pop of loss that miraculously recovered. It happened during learning rate annealing, looked completely innocent, and recovered after about 100 training steps. A common theory is that the optimizer hits a thin wall in the loss landscape. I agree with this, as I did not find any other anomalies in the training logs. It is simply fascinating to see it find its way back and pretend (almost) nothing happened.
It is worth noting that similar loss spikes have been reported in a few LLM training papers, such as LLaMA (though far less severe than this one).