Continued Pre-Training

Goal: to instill new knowledge in LLMs without pretraining from scratch.

(Ibrahim et al., 2024) performed extensive experiments to derive insights into effective strategies for continued pretraining. Two key findings:

  • Re-warming and re-decaying the learning rate: apply the exact same learning-rate schedule (warmup followed by decay) that was used during the initial pretraining stage.
  • Adding a small portion (e.g., 5%) of the original pretraining data to the new dataset to prevent catastrophic forgetting. Note that smaller fractions such as 0.5% and 1% were also effective. (A minimal sketch of both ideas follows below.)
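A minimal sketch of both ideas, assuming samples are stored in plain Python lists and a warmup-plus-cosine schedule; the function names and default values are illustrative, not taken from the paper:

```python
import math
import random

def rewarmed_lr(step, total_steps, max_lr=3e-4, min_lr=3e-5, warmup_steps=1000):
    """Re-warm the learning rate from ~0 to max_lr, then re-decay it with a
    cosine schedule back to min_lr, mirroring the original pretraining schedule."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

def mixed_batch(new_data, original_data, batch_size, replay_fraction=0.05):
    """Build a batch in which ~5% of the samples are replayed from the original
    pretraining corpus to mitigate catastrophic forgetting."""
    n_replay = max(1, int(batch_size * replay_fraction))
    batch = random.sample(new_data, batch_size - n_replay)
    batch += random.sample(original_data, n_replay)
    random.shuffle(batch)
    return batch
```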

Fig. A schedule for continued pretraining. Image Credits (Raschka, 2023)

Weight Averaging

Weight averaging involves averaging a single model’s weights (parameters) at different points in its training process. Typically, it’s done towards the end of the training when the model has nearly converged.

Stochastic Weight Averaging (SWA)

In SWA (Izmailov et al., 2019), an initially large learning rate is decayed, and the weights are averaged over several iterations during periods of decayed (but still relatively high) learning rates. Since a model's training trajectory can be uneven, the strategy is to average the weights towards the end of training, when the learning rate is low (if a scheduler is used) and training is nearing convergence. A modified learning rate schedule allows SGD (or other optimizers such as Adam) to keep bouncing around the optimum and explore diverse models instead of simply converging to a single solution.

Fig. Stochastic Weight Averaging (SWA) averages a model’s weights towards the end of the training cycle. Image Credits (Raschka, 2023)
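PyTorch ships utilities for this. A minimal training-loop sketch, assuming `model`, `train_loader`, and `loss_fn` are defined elsewhere and the exact epochs/learning rates are illustrative:

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

# Assumed to be defined elsewhere: model, train_loader, loss_fn.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)               # keeps a running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # decayed but still relatively high LR
swa_start, epochs = 75, 100

for epoch in range(epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:                     # only average towards the end of training
        swa_model.update_parameters(model)
        swa_scheduler.step()

update_bn(train_loader, swa_model)             # recompute BatchNorm statistics for the averaged model
```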

Intuition for why SWA works (from the first author of (Izmailov et al., 2019)):

Averaging works in cases where the parameters you are averaging are oriented around a local region of low loss, helping to move the solution to a more centred “flat” solution. A flat solution, at a high level, corresponds to parameters that can be perturbed without significantly increasing the loss. These solutions generalize better, in short, because they provide a better compression of the data: they can be represented with fewer bits of precision. It’s actually quite intuitive!

The bigger the model class, the larger these regions of low loss will be, because there are many more parameters that will be consistent with the data, and therefore the greater the opportunities to benefit from averaging. Really big models are also hardly trained to completion, so averaging many checkpoints, or across multiple fine-tuning runs, can help find more refined solutions more quickly. This is closely related to a procedure called SWALP (Yang et al., 2019), which enables training in low precision by combining weights that have been rounded up with those that have been rounded down. Model soups (Wortsman et al., 2022) work for exactly the same reason as SWA (and were very much inspired by SWA): the fine-tuning runs each have the same warm start and end up exploring a local region of space.

The theory of linear mode connectivity of model weights (Juneja et al., 2023) suggests that models that start from a similar position, or are fine-tuned in similar ways, end up in the same “region” of loss space, and that linearly interpolating between them usually yields a similarly good (if not better) model, which is similar in spirit to SWA.

Latest Weight Averaging (LaWA)

(Kaddour, 2022) demonstrated that averaging the weights of the latest k checkpoints, each taken at the end of an epoch, can expedite training progress in terms of loss and accuracy by several epochs. (Sanyal et al., 2023) explored a modified version of LaWA with higher learning rates and an earlier start in averaging checkpoints during training. The researchers found that this approach significantly outperformed standard SWA and EMA techniques.
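A minimal sketch of LaWA-style checkpoint averaging, assuming checkpoints were saved as plain `state_dict` files ordered from oldest to newest; file names and `k` are illustrative:

```python
from collections import deque
import torch

def lawa_average(checkpoint_paths, k=5):
    """Uniformly average the parameters of the latest k checkpoints."""
    latest = deque(checkpoint_paths, maxlen=k)     # keep only the most recent k paths
    avg = None
    for path in latest:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {name: p.clone().float() for name, p in state.items()}
        else:
            for name, p in state.items():
                avg[name] += p.float()
    return {name: p / len(latest) for name, p in avg.items()}

# Usage sketch: average the last 5 end-of-epoch checkpoints, then load them into the model.
# model.load_state_dict(lawa_average([f"ckpt_epoch_{i}.pt" for i in range(1, 11)]))
```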

WARM

WARM (Ramé et al., 2024) aims to improve RLHF for LLMs. Specifically, the researchers attempt to mitigate reward hacking by averaging the weights of finetuned reward models (RMs). Reward hacking occurs when an LLM learns to manipulate or exploit flaws in its reward system to attain high scores or rewards, without genuinely fulfilling the intended task or achieving the underlying objectives. WARM averages the weights of multiple RMs using a simple linear average, as in stochastic weight averaging. The difference, however, is that the models are not sampled from the same training trajectory but are created independently from the pretrained model, as in Model Ratatouille. Alternatively, WARM also offers a so-called Baklava procedure that samples models along a fine-tuning trajectory.
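A minimal sketch of the averaging step, assuming the reward models share the same architecture and were finetuned independently from the same pretrained initialization; the function name and the optional non-uniform weights are illustrative:

```python
import torch

def warm_average(state_dicts, weights=None):
    """Linearly average the weights of M independently finetuned reward models
    (uniform average by default, as in stochastic weight averaging)."""
    m = len(state_dicts)
    weights = weights or [1.0 / m] * m
    avg = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in state_dicts[0].items()}
    for lam, sd in zip(weights, state_dicts):
        for k, v in sd.items():
            avg[k] += lam * v.float()
    return avg  # load into a single RM and use it as the reward signal in RLHF
```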

Fig. An outline of how WARM is used in the RLHF process. The only new aspect is that the method uses a reward model obtained via weight averaging instead of training a single reward model (annotated figure from the WARM paper). Image Credits (Raschka, 2023)

Fig. A comparison between the different model merging and averaging methods. Image Credits (Raschka, 2023)

Model Merging

Model merging involves combining multiple different trained models into a single model. Model Ratatouille (Ramé et al., 2023) proposes to reuse multiple fine-tuned versions of the same base model, each trained on a different auxiliary task.

Model Soups (Wortsman et al., 2022) showed that averaging the weights of multiple models finetuned with different hyperparameter configurations often improves accuracy and robustness. The key difference between this and an ensemble is that no inference-time penalty is incurred.
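A sketch of the paper's greedy-soup variant, which adds a finetuned model to the soup only if the averaged weights improve held-out accuracy. It assumes a hypothetical `evaluate(model, loader)` helper returning validation accuracy and models pre-sorted by individual accuracy:

```python
import copy
import torch

def greedy_soup(finetuned_models, evaluate, val_loader):
    """Greedy soup: grow the set of averaged models one ingredient at a time,
    keeping an ingredient only if it does not hurt validation accuracy."""
    soup = [finetuned_models[0].state_dict()]
    best_acc = evaluate(finetuned_models[0], val_loader)
    for model in finetuned_models[1:]:
        candidate = soup + [model.state_dict()]
        averaged = {k: sum(sd[k].float() for sd in candidate) / len(candidate)
                    for k in candidate[0]}
        probe = copy.deepcopy(finetuned_models[0])
        probe.load_state_dict(averaged)
        acc = evaluate(probe, val_loader)
        if acc >= best_acc:                      # keep the ingredient only if it helps
            soup, best_acc = candidate, acc
    return {k: sum(sd[k].float() for sd in soup) / len(soup) for k in soup[0]}
```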

Fig. The Model Ratatouille method for model merging. Image Credits (Raschka, 2023)

Mixture of Experts (MoE)

A Mixture of Experts is a type of ensemble model that combines several smaller expert subnetworks. Each subnetwork is responsible for handling different types of tasks or, more concretely, tokens, which lets the model allocate computational resources more efficiently.

Switch Transformers

(Fedus et al., 2022) introduced the Switch Transformer, a sparse MoE architecture that routes each token to a single expert, simplifying routing and keeping the compute per token roughly constant while scaling the total parameter count.

Fig. Annotated figure of Switch Transformers. Image Credits (Raschka, 2023)

Mixtral 8x7B

Mixtral 8x7B (Jiang et al., 2024) is a sparse MoE that performs on par with Llama 2 70B. It uses 8 experts in total, with 2 experts active per token, built from copies of the Mistral 7B model. The total number of parameters is 47B (not 56B, since only the FFN modules are replicated). The router directs each token so that only about 13B parameters (2x <7B, instead of all 47B) are used per forward pass, so training (and especially inference) is faster compared to a traditional non-MoE approach.

Architecture: Replaces each feed-forward module in a transformer architecture with 8 expert layers.

Fig. Annotated transformer architecture. Image Credits (Raschka, 2023)

The routing module (also known as a gating network, \(G\)) computes the output as \(\sum_{i=1}^{8} G(x)_i \cdot E_i(x)\), where the \(E_i(x)\) are the expert outputs. At first glance, it might seem like Mixtral is simply adding additional parameters to an LLM via these expert (feed-forward) modules to represent a sort of weighted ensemble. However, there’s an additional tweak: Mixtral is a sparse MoE, which means that only a subset of the experts is used for each input (TopK = 2), i.e., \(G(x) := \mathrm{Softmax}(\mathrm{TopK}(x \cdot W_g))\).
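A minimal, deliberately naive sparse-MoE layer with top-2 routing in this spirit; the dimensions and the per-expert loop are illustrative and not Mixtral's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseMoE(nn.Module):
    """Sparse MoE layer: a linear gate W_g picks the top-k experts per token,
    and the selected experts' outputs are combined with softmaxed gate weights."""
    def __init__(self, d_model=64, d_hidden=256, n_experts=8, top_k=2):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.SiLU(), nn.Linear(d_hidden, d_model))
            for _ in range(n_experts)
        )
        self.gate = nn.Linear(d_model, n_experts, bias=False)  # W_g
        self.top_k = top_k

    def forward(self, x):                        # x: (n_tokens, d_model)
        logits = self.gate(x)                    # (n_tokens, n_experts)
        top_vals, top_idx = logits.topk(self.top_k, dim=-1)
        weights = F.softmax(top_vals, dim=-1)    # softmax over the selected experts only
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = top_idx[:, slot] == e     # tokens routed to expert e in this slot
                if mask.any():
                    out[mask] += weights[mask, slot].unsqueeze(-1) * expert(x[mask])
        return out

# Usage sketch: moe = SparseMoE(); y = moe(torch.randn(10, 64))
```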

Fig. Annotated figure from the Mixtral of Experts paper explaining the MoE module. Image Credits (Raschka, 2023)

Expert specialization: consecutive tokens in text datasets are often assigned to the same experts. Additionally, indentation tokens in Python code are frequently assigned to the same expert.

Proxy Tuning

Proxy-tuning (Liu et al., 2024) works through a straightforward process at the decoding stage by adjusting the logits of the target LLM. Specifically, it calculates the difference in logits between a smaller base model and its finetuned counterpart, and then adds this difference to the logits of the target model.
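A sketch of the logit arithmetic for one decoding step, assuming all logits have shape `(batch, vocab_size)`; this is not the authors' code:

```python
import torch
import torch.nn.functional as F

def proxy_tuned_next_token(target_logits, small_tuned_logits, small_base_logits):
    """Shift the large target model's next-token logits by the difference between
    a small finetuned model and its untuned base, then sample from the result.
    All three models must share the same vocabulary."""
    adjusted = target_logits + (small_tuned_logits - small_base_logits)
    probs = F.softmax(adjusted, dim=-1)
    return torch.multinomial(probs, num_samples=1)   # sampled next-token ids
```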

Fig. Annotated illustration of proxy-tuning. Image Credits (Raschka, 2023)

Benefits:

  • It might outperform LoRA in certain contexts.
  • It’s useful when the large base model is a “black box”, and its internal weights are inaccessible.

However, the smaller models must share the same vocabulary as the larger target model.

References

  1. Ibrahim, A., Thérien, B., Gupta, K., Richter, M. L., Anthony, Q., Lesort, T., Belilovsky, E., & Rish, I. (2024). Simple and Scalable Strategies to Continually Pre-train Large Language Models. https://arxiv.org/abs/2403.08763
  2. Raschka, S. (2023). Noteworthy AI Research Papers of 2024 (Part One) [Blog]. https://magazine.sebastianraschka.com/p/ai-research-papers-2024-part-1
  3. Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., & Wilson, A. G. (2019). Averaging Weights Leads to Wider Optima and Better Generalization. https://arxiv.org/abs/1803.05407
  4. Raschka, S. (2023). Model Merging, Mixtures of Experts, and Towards Smaller LLMs [Blog]. https://magazine.sebastianraschka.com/p/ai-research-papers-2024-part-1
  5. Yang, G., Zhang, T., Kirichenko, P., Bai, J., Wilson, A. G., & Sa, C. D. (2019). SWALP : Stochastic Weight Averaging in Low-Precision Training. https://arxiv.org/abs/1904.11943
  6. Wortsman, M., Ilharco, G., Gadre, S. Y., Roelofs, R., Gontijo-Lopes, R., Morcos, A. S., Namkoong, H., Farhadi, A., Carmon, Y., Kornblith, S., & Schmidt, L. (2022). Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. https://arxiv.org/abs/2203.05482
  7. Juneja, J., Bansal, R., Cho, K., Sedoc, J., & Saphra, N. (2023). Linear Connectivity Reveals Generalization Strategies. https://arxiv.org/abs/2205.12411
  8. Kaddour, J. (2022). Stop Wasting My Time! Saving Days of ImageNet and BERT Training with Latest Weight Averaging. https://arxiv.org/abs/2209.14981
  9. Sanyal, S., Neerkaje, A., Kaddour, J., Kumar, A., & Sanghavi, S. (2023). Early Weight Averaging meets High Learning Rates for LLM Pre-training. https://arxiv.org/abs/2306.03241
  10. Ramé, A., Vieillard, N., Hussenot, L., Dadashi, R., Cideron, G., Bachem, O., & Ferret, J. (2024). WARM: On the Benefits of Weight Averaged Reward Models. https://arxiv.org/abs/2401.12187
  11. Ramé, A., Ahuja, K., Zhang, J., Cord, M., Bottou, L., & Lopez-Paz, D. (2023). Model Ratatouille: Recycling Diverse Models for Out-of-Distribution Generalization. https://arxiv.org/abs/2212.10445
  12. Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. https://arxiv.org/abs/2101.03961
  13. Jiang, A. Q., Sablayrolles, A., Roux, A., Mensch, A., Savary, B., Bamford, C., Chaplot, D. S., de las Casas, D., Hanna, E. B., Bressand, F., Lengyel, G., Bour, G., Lample, G., Lavaud, L. R., Saulnier, L., Lachaux, M.-A., Stock, P., Subramanian, S., Yang, S., … Sayed, W. E. (2024). Mixtral of Experts. https://arxiv.org/abs/2401.04088
  14. Liu, A., Han, X., Wang, Y., Tsvetkov, Y., Choi, Y., & Smith, N. A. (2024). Tuning Language Models by Proxy. https://arxiv.org/abs/2401.08565