Training Large Language Models to Reason in a Continuous Latent Space

Author

Kevin Jablonka

Published

December 14, 2024

Making LLMs better at reasoning

A lot of current research goes into making LLMs better at reasoning. Much of it is motivated by the fact that current systems “think” with the same “intensity” for every token they produce. Many believe that models would perform better if they could “think harder” on harder tasks.

Ilya Sutskever at NeurIPS 2024. Reasoning is very prominent on many research agendas. Full video on YouTube

Chain of Thought (CoT)

Chain-of-thought prompting is a surprisingly simple method that has been shown to improve the performance of LLMs on various benchmarks. Wei et al. (Wei et al. 2022) showed in their paper (almost 10k citations) that including demonstrations of intermediate reasoning steps in the prompt improves performance.

Kojima et al. (around 3.5k citations) (Kojima et al. 2022) showed that simply adding “Let’s think step by step” to the prompt can yield comparable performance boosts.
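As a toy illustration (the question string is something I made up, not taken from any benchmark), zero-shot CoT differs from standard prompting only by the appended trigger phrase:

```python
# Standard prompt vs. zero-shot CoT prompt (Kojima et al. 2022).
# The question is a made-up example, not from a benchmark.
question = "A farmer has 17 sheep. All but 9 run away. How many are left?"

standard_prompt = f"Q: {question}\nA:"

# Appending a single trigger phrase elicits step-by-step reasoning.
zero_shot_cot_prompt = f"Q: {question}\nA: Let's think step by step."

print(zero_shot_cot_prompt)
```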

Note

Various CoT variants have been proposed. I think Lilian Weng’s take that many of those things should be blog posts and not papers is true.

Some useful variants are shown in (Fu et al. 2023). The paper also illustrates how sensitive CoT is to the prompt format (e.g., changing “Q:” to “Question:”, or using “\n” instead of “.” to separate steps).

Some follow-up work has argued that one explanation for the improved performance of CoT reasoning is that generating the extra intermediate tokens increases the effective compute spent per answer.

Note

CoT can also be thought of as one flavor of test-time compute. This is currently one of the most promising and active streams of research for increasing the performance of LLMs. Models like o1 (Qin et al. 2024) lean heavily on test-time compute (i.e., “thinking” more at inference time, ideally with the amount of thinking proportional to the difficulty of the question).

Internalized CoT

Since generating extra reasoning tokens at inference time is expensive, researchers have attempted to internalize the reasoning pathway.

Tip

To get training to converge, the authors found it important to reset the optimizer state. Optimizers such as AdamW keep running averages of gradient moments, and those cause problems when the loss function suddenly changes.
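A minimal PyTorch sketch of that idea, assuming a staged curriculum loop (the toy model, loss, and stage/step counts are placeholders, not the authors' code): re-creating the optimizer at each stage boundary discards AdamW's stale moment estimates.

```python
import torch

# Toy model standing in for the language model (placeholder, not the real setup).
model = torch.nn.Linear(16, 16)

num_stages = 4  # curriculum stages; each stage changes the training objective

for stage in range(num_stages):
    # Re-create the optimizer at every stage boundary: AdamW keeps running
    # averages of first/second gradient moments, and those estimates are
    # stale once the loss function changes between stages.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for step in range(100):  # placeholder inner training loop
        optimizer.zero_grad()
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()  # dummy loss for illustration
        loss.backward()
        optimizer.step()
```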

CoT traces are not faithful and might be limiting

A tweet by one of the godfathers. Reasoning might not need to be verbalized.

It is well known that CoT traces are not faithful. Turpin et al. showed this by adding biases to the prompt: the biases led to drops in model performance but were never mentioned in the model’s explanations. This experiment directly shows that the verbalized explanations are not faithful, since a cause of the changed predictions is not verbalized.

Some anthropomorphize this by linking it to neuroscience results suggesting that “language is primarily a tool for communication rather than thought” (Fedorenko, Piantadosi, and Gibson 2024).

Roles of the language network. Taken from (Fedorenko, Piantadosi, and Gibson 2024). Subfigure b shows that the language network is not strongly activated for non-linguistic tasks.

In addition, it is notable that CoT restricts models to a single discrete reasoning path, whereas it might be more effective to explore multiple paths. A relevant work that does this is Tree of Thoughts (ToT) (Yao et al. 2023). ToT creates a branching structure of multiple potential solutions (a minimal code sketch follows the list below):

  1. Similar to CoT it breaks down problems into sequential thought steps
  2. But it generates multiple alternative thoughts at each step
  3. This can be used to create a tree-like structure of reasoning paths
  4. These trees can be explored using either:
    • Breadth-first search (BFS)
    • Depth-first search (DFS)
  5. Each node (state) can be evaluated using a classifier or majority vote
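A minimal sketch of the BFS variant, assuming placeholder `propose_thoughts` and `score_state` functions (in the actual method these would be LLM calls; here they are stubs so the snippet runs end to end):

```python
from typing import Callable


def tree_of_thoughts_bfs(
    root: str,
    propose_thoughts: Callable[[str], list[str]],  # generates alternative next thoughts
    score_state: Callable[[str], float],           # evaluates a partial reasoning state
    depth: int = 3,
    beam_width: int = 2,
) -> str:
    """Breadth-first Tree-of-Thoughts search: expand, score, keep the best states."""
    frontier = [root]
    for _ in range(depth):
        # Expand every state in the frontier with several alternative thoughts.
        candidates = [state + "\n" + t for state in frontier for t in propose_thoughts(state)]
        # Keep only the most promising states (the "beam").
        frontier = sorted(candidates, key=score_state, reverse=True)[:beam_width]
    return frontier[0]


if __name__ == "__main__":
    # Toy stand-ins for the LLM-based proposer and evaluator.
    best = tree_of_thoughts_bfs(
        root="Problem: ...",
        propose_thoughts=lambda state: [f"thought option {i}" for i in range(3)],
        score_state=lambda state: float(len(state)),  # placeholder evaluator
    )
    print(best)
```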

Methods

The key idea presented in the paper is to not verbalize “thoughts” as language tokens but instead use the hidden representations that the model produces as richer “thought vectors” that could, in principle, also encode multiple reasoning pathways at the same time.

Method proposed by Coconut compared to CoT.

That is, the approach continues to “think” in its internal representation and only verbalizes the final answer.
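My rough sketch of that mechanism, using GPT-2 from Hugging Face as a stand-in (the model choice, the fixed number of latent steps, and the greedy answer decoding are my simplifications, not the authors' implementation):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Sketch of "continuous thoughts": the last hidden state is fed back as the
# next input embedding instead of being decoded into a token.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

question = "What is 3 + 4 * 2?"
input_ids = tokenizer(question, return_tensors="pt").input_ids
inputs_embeds = model.transformer.wte(input_ids)  # token embeddings

num_latent_thoughts = 3  # arbitrary number of latent steps for illustration
with torch.no_grad():
    for _ in range(num_latent_thoughts):
        out = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
        # Last hidden state at the final position = one "continuous thought".
        thought = out.hidden_states[-1][:, -1:, :]
        # Append it as the next input "embedding" without verbalizing it.
        inputs_embeds = torch.cat([inputs_embeds, thought], dim=1)

    # Only the final answer is verbalized, here via greedy next-token decoding.
    logits = model(inputs_embeds=inputs_embeds).logits[:, -1, :]
    print(tokenizer.decode(logits.argmax(dim=-1)))
```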

Training approach

To train this model, the authors use a protocol inspired by the one used for internalized CoT: over multiple stages, they progressively replace verbalized thinking steps with latent ones.

Training protocol used by the authors.
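To make the staged replacement concrete, here is a rough sketch of how a training sequence could be assembled at a given curriculum stage (the special tokens and the text format are assumptions for illustration, not the paper's exact data pipeline):

```python
# Rough sketch of the curriculum: at stage k, the first k verbalized reasoning
# steps are replaced by latent "thought" placeholders that the model fills with
# continuous hidden states during the forward pass.
# The token names (<bot>, <thought>, <eot>) are placeholders of my own.

def build_training_sequence(question: str, cot_steps: list[str], answer: str,
                            stage: int, thoughts_per_step: int = 1) -> str:
    latent = "<bot> " + "<thought> " * (stage * thoughts_per_step) + "<eot>"
    remaining = " ".join(cot_steps[stage:])  # steps not yet internalized stay verbal
    return f"{question} {latent} {remaining} ### {answer}"


example = build_training_sequence(
    question="Q: 2 apples + 3 apples?",
    cot_steps=["Step 1: count the apples.", "Step 2: 2 + 3 = 5."],
    answer="5",
    stage=1,
)
print(example)
# Q: 2 apples + 3 apples? <bot> <thought> <eot> Step 2: 2 + 3 = 5. ### 5
```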

Results

In their benchmarks, the authors observed promising results for their approach. They showed that it outperforms other “internalized thought” techniques.

| Method | GSM8k Acc. (%) | GSM8k # Tokens | ProntoQA Acc. (%) | ProntoQA # Tokens | ProsQA Acc. (%) | ProsQA # Tokens |
|---|---|---|---|---|---|---|
| CoT | \(42.9 \pm 0.2\) | 25.0 | \(98.8 \pm 0.8\) | 92.5 | \(77.5 \pm 1.9\) | 49.4 |
| No-CoT | \(16.5 \pm 0.5\) | 2.2 | \(93.8 \pm 0.7\) | 3.0 | \(76.7 \pm 1.0\) | 8.2 |
| iCoT | \(30.0^*\) | 2.2 | \(99.8 \pm 0.3\) | 3.0 | \(98.2 \pm 0.3\) | 8.2 |
| Pause Token | \(16.4 \pm 1.8\) | 2.2 | \(77.7 \pm 21.0\) | 3.0 | \(75.9 \pm 0.7\) | 8.2 |
| Coconut (Ours) | \(34.1 \pm 1.5\) | 8.2 | \(99.8 \pm 0.2\) | 9.0 | \(97.0 \pm 0.3\) | 14.2 |
| - w/o curriculum | \(14.4 \pm 0.8\) | 8.2 | \(52.4 \pm 0.4\) | 9.0 | \(76.1 \pm 0.2\) | 14.2 |
| - w/o thought | \(21.6 \pm 0.5\) | 2.3 | \(99.9 \pm 0.1\) | 3.0 | \(95.5 \pm 1.1\) | 8.2 |
| - pause as thought | \(24.1 \pm 0.7\) | 2.2 | \(100.0 \pm 0.1\) | 3.0 | \(96.6 \pm 0.8\) | 8.2 |

They also found that increasing the number of latent thoughts per thinking step increases performance.

By decoding the hidden thoughts, the authors could assign probabilities to the different options that Coconut “explored”. This can be used to construct search trees.

Search tree proposed for ProsQA.

The trees are a bit more “anecdotal” (there are no systematic statistics), but they offer an interesting perspective on the results.
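A sketch of what “decoding a continuous thought” could look like in practice: projecting the latent vector through the language-model head and reading off a token distribution (GPT-2 and the random thought vector are stand-ins for illustration, not the paper's procedure):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Project a "continuous thought" vector through the LM head to inspect which
# tokens (i.e., candidate reasoning branches) it puts probability on.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

thought = torch.randn(1, model.config.n_embd)  # placeholder latent thought vector
with torch.no_grad():
    logits = model.lm_head(thought)            # decode the thought via the LM head
probs = torch.softmax(logits, dim=-1)

top = torch.topk(probs, k=5)
for p, idx in zip(top.values[0], top.indices[0]):
    print(f"{tokenizer.decode([idx.item()])!r}  p={p.item():.3f}")
```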

Conclusions

  • The current protocol is computationally expensive to train (it requires multiple sequential forward passes). Doing this at scale would require the development of suitable infrastructure.
  • It is nice to see new paradigms being explored (with smaller models).
  • Some of this also links to agents (instead of letting them communicate via text, they could exchange hidden representations).
Fedorenko, Evelina, Steven T. Piantadosi, and Edward A. F. Gibson. 2024. “Language Is Primarily a Tool for Communication Rather Than Thought.” Nature 630 (8017): 575–86. https://doi.org/10.1038/s41586-024-07522-w.
Fu, Yao, Hao Peng, Ashish Sabharwal, Peter Clark, and Tushar Khot. 2023. “Complexity-Based Prompting for Multi-Step Reasoning.” https://arxiv.org/abs/2210.00720.
Guo, Jeff, and Philippe Schwaller. 2024. “It Takes Two to Tango: Directly Optimizing for Constrained Synthesizability in Generative Molecular Design.” arXiv. https://doi.org/10.48550/ARXIV.2410.11527.
Kojima, Takeshi, Shixiang Shane Gu, Machel Reid, Yutaka Matsuo, and Yusuke Iwasawa. 2022. “Large Language Models Are Zero-Shot Reasoners.” Advances in Neural Information Processing Systems 35: 22199–213.
Landrum, Gregory A., Jessica Braun, Paul Katzberger, Marc T. Lehner, and Sereina Riniker. 2024. “Lwreg: A Lightweight System for Chemical Registration and Data Storage.” Journal of Chemical Information and Modeling 64 (16): 6247–52. https://doi.org/10.1021/acs.jcim.4c01133.
Orsi, Markus, and Jean-Louis Reymond. 2024. “One Chiral Fingerprint to Find Them All.” Journal of Cheminformatics 16 (1): 53.
Qin, Yiwei, Xuefeng Li, Haoyang Zou, Yixiu Liu, Shijie Xia, Zhen Huang, Yixin Ye, et al. 2024. “O1 Replication Journey: A Strategic Progress Report – Part 1.” https://arxiv.org/abs/2410.18982.
Wei, Jason, Xuezhi Wang, Dale Schuurmans, Maarten Bosma, Fei Xia, Ed Chi, Quoc V Le, Denny Zhou, et al. 2022. “Chain-of-Thought Prompting Elicits Reasoning in Large Language Models.” Advances in Neural Information Processing Systems 35: 24824–37.
Yao, Shunyu, Dian Yu, Jeffrey Zhao, Izhak Shafran, Thomas L. Griffiths, Yuan Cao, and Karthik Narasimhan. 2023. “Tree of Thoughts: Deliberate Problem Solving with Large Language Models.” https://arxiv.org/abs/2305.10601.