NL-226, Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking, COLM 2024

◼ Comment

  • STaR 논문의 후속작으로 확장한 것이다
  • STaR은 (QA 테스크같은) 정답이 비교적 정해져있는 테스크에서 reasoning 데이터를 확장하여 모델 성능을 올리는 것이다
  • Quiet-STaR 은 말그대로 조용한 STaR로 명시적으로 reasoning을 생성하고 그런것은 아니지만, LM에 think란 과정을 도입하여 성능을 향상시키는 것이다.
  • 즉 일상대화 뿐만 아니라 딱히 도메인에 상관없이 적용할 수 있는 방법이다.
  • 크게 3가지 step이 있다.
    • think
    • talk
    • learn
  • think
    • 각 토큰에서 think 시퀀스를 만들어낸다. 즉 이 부분이 quiet-reasnoing이라고 볼 수 있음
    • 5번째 토큰에서 모델을 학습시킨다고 하면 일반적으론 6번째 토큰이 생성되도록 1-5 토큰들이 입력으로 들어가는 것이다
    • 여기서는 1-5 토큰을 모고 think token을 생성한다. 
    • 즉 1-5번째 토큰뒤에 생성될 수 있는 토큰들을 생성하는 것인데, special token을 넣어서 이 부분은 think 하는 부분이라고 모델에게 알려준다
    • think sequence을 하나만 생성하진 않고 여러개 생성해서 샘플링해서 사용한다고 한다
  • talk
    • 1-5번째 토큰들을 X라고 하고 think sequence T라고 해보자
    • X와 T을 결합해서 6번째 토큰을 예측할 수 있고
    • X만으로 6번째 토큰을 예측할 수 있다 (기존방법)
    • 이 2가지 예측확률을 weighted sum해서 mixed 확률을 계산한다
    • weight는 MLP을 이용해서 예측한다고 한다
    • 여기서는 6번째 토큰만을 예측하도록 학습하지 않고 ntrue라고해서 뒤에 3개(예시)의 미래토큰을 예측하도록 학습시킨다.
    • 즉 6,7,8번째 토큰까지 한꺼번에 예측하도록 학습한다고 한다.
    • 4.4.2에 왜 그러는지 간단히 이유를 간단히 말하는데 정확하게 이해한건 X
  • learn
    • mixed 확률로 loss 계산하는게 NLL loss
    • 이를 이용하여 reward을 계산해서 T가 생성되는 부분을 학습하는게 REINFORCEMENT loss
    • 이 2가지 loss을 더해서 최종 학습을 한다
  • 단점으로는 딱 봐도 계산량이 너무 많이 들어 보인다
    • 또한 주어진 시퀀스에 동적으로 think을 컨트롤할 수 없다
    • 항상 think을 하게 하고 결합할지 말지를 결정하는 것이기 때문에 한계가 있는거 같음
    • 7B 모델에서만 실험해서 더 큰 모델에서 어떨지 궁금하긴하다
    • 좀 더 자세한건 limitation 참고

Abstract

  • 사람들이 글을 쓰거나 말을 할 때, 때때로 생각하기 위해 잠시 멈추는 경우가 있습니다. 
  • 전통적으로 논리적 사고는 질문에 답하거나 과제를 완수하는 방법으로 여겨졌지만, 실제로는 거의 모든 글에 암묵적으로 논리적 사고가 포함됩니다. 
  • 예를 들어, 증명 사이에 명시되지 않은 단계들이나 대화에서 타인의 마음을 추론하는 과정이 여기에 해당됩니다. 
  • 2022년 Zelikman 등이 제안한 Self-Taught Reasoner(STaR)에서는 질문-답변 형식의 몇 가지 예시를 통해 유용한 사고 방식을 배우고, 올바른 답을 이끄는 이유를 추론하는 방식을 채택했습니다. 
  • 그러나 이는 매우 제한된 환경입니다. 
  • 이상적으로는 언어 모델이 다양한 텍스트에서 암묵적인 논리를 추론할 수 있어야 합니다.
  • 우리는 STaR을 일반화한 Quiet-STaR을 제안합니다. 
  • 이 모델에서는 언어 모델이 향후 텍스트를 설명하기 위해 각 토큰에서 이유를 생성함으로써 예측 성능을 향상시킵니다. 
  • 이 과정에서 몇 가지 주요 과제가 있습니다: 
    • 1) 텍스트의 연속적인 생성을 처리하는 데 드는 계산 비용, 
    • 2) 언어 모델이 처음에는 내부 사고를 생성하거나 사용하는 방법을 모른다는 점, 
    • 3) 개별 다음 토큰을 넘어선 예측이 필요하다는 점. 
    • 이를 해결하기 위해 우리는 토큰 단위 병렬 샘플링 알고리즘을 제안하고, 사고의 시작과 끝을 나타내는 학습 가능한 토큰과 확장된 teacher-forcing 기법을 사용했습니다.
  • 고무적인 점은 생성된 이유들이 예측하기 어려운 토큰에 특히 도움이 되어 모델이 어려운 질문에 직접 답할 수 있는 능력을 향상시킨다는 것입니다. 
  • 특히, Quiet-STaR을 사용하여 인터넷 텍스트 코퍼스에 대해 언어 모델을 추가 학습한 후 GSM8K에서 5.9%에서 10.9%로, CommonsenseQA에서 36.3%에서 47.2%로 향상된 제로샷 성능을 보였으며, 자연 텍스트에서 어려운 토큰의 퍼플렉시가 개선되었습니다. 
  • 중요한 점은 이러한 성능 향상이 이러한 과제에 대한 별도의 미세 조정 없이 이루어졌다는 것입니다. 
  • Quiet-STaR은 언어 모델이 더 일반적이고 확장 가능한 방식으로 논리적 사고를 학습하는 데 한 걸음 나아간 방법입니다.

1 Introduction

  • 텍스트의 많은 의미는 행간에 숨겨져 있습니다. 
  • 문서에서 왜 특정 문장이 등장하는지를 이해하지 못하면 독자는 얕은 수준의 이해에 그치게 됩니다. 
  • 이는 상식 추론, 정리 증명, 프로그래밍과 같은 다양한 과제에서 언어 모델(LM)에도 해당된다는 것이 반복적으로 증명되었습니다 (Wei et al., 2022b; Nye et al., 2021; Zelikman et al., 2022; 2023a; Kojima et al., 2022). 
  • 텍스트의 암시를 추론하여 이후의 텍스트를 예측하는 것은 다양한 과제에서 언어 모델 성능을 꾸준히 향상시키는 것으로 나타났지만, LM이 그 자체의 추론으로부터 학습할 수 있도록 하는 방법들(Zelikman et al., 2022 등)은 개별 과제나 사전에 정의된 일련의 과제를 해결하는 데 중점을 두었습니다 (Wei et al., 2021b). 
    • 이러한 연구들은 특정 추론 작업을 제공하거나 경우에 따라서는 추론 자체를 제공하는 정교하게 조정된 데이터셋에 의존합니다. 
  • 반면 우리는, 모든 텍스트에 암시적인 추론이 존재한다면, 왜 언어 모델링 작업을 활용해 추론을 가르치지 말아야 하는지 질문합니다.
  • 특히, Self-Taught Reasoner(STaR, Zelikman et al. 2022)는 언어 모델이 질문-답변(QA) 데이터셋에서 추론 능력을 부트스트래핑할 수 있음을 보여주었습니다. 
    • 모델이 질문에 답하려고 이유(rationale)를 샘플링하고, 그 이유가 올바른 답을 도출할 경우 학습한 다음, 더 어려운 문제를 해결하기 위해 이를 반복하는 방식입니다. 
    • 그러나 큐레이팅된 QA 데이터셋으로 학습하는 것은 그 이유의 범위와 일반성을 제한합니다. 
    • 고품질의 QA 데이터셋은 신중한 큐레이션이 필요하며, 본질적으로 특정 추론 작업의 일부만을 다룹니다. 
    • 따라서 우리는 STaR을 확장하여 LM이 수학적 QA와 같은 특정 과제에서 추론하는 것이 아니라, 대규모 인터넷 텍스트 코퍼스에서 향후 텍스트를 추론하는 데 도움이 되는 추론을 생성하도록 학습시킵니다. 
    • STaR은 정답이 있어야하는 문제같은 경우(Q&A 데이터)에만 확장이 가능한거고, Quiet-STaR은 이를 일반화 시켜 일반적인 텍스트를 생성하는데도 추론이 가능하도록 하는 것
    • 즉 암시적인 추론을 LM에 가르친다
  • 이를 통해 언어에 존재하는 다양한 작업으로부터 LM이 학습할 수 있게 됩니다 (Weber et al., 2021). 
  • 이는 현재의 언어 모델링 패러다임에서 중요한 직관에 기반하며, 즉 "언어 모델은 비지도 다중 작업 학습자"라는 개념입니다 (Radford et al., 2019). 
  • STaR에서와 같이 우리는 LM의 기존 추론 능력을 활용하여 이유를 생성하고, REINFORCE 기반 보상(Williams, 1992)을 사용해 LM을 학습시킵니다. 
  • 우리는 이 기법을 Quiet-STaR이라고 부르며, 이는 STaR을 "조용히" 적용하여 모델이 말하기 전에 생각하도록 학습하는 것으로 이해할 수 있습니다.
  • 전반적으로 Quiet-STaR은 각 토큰 뒤에 이유를 생성하여 향후 텍스트를 설명하도록 하며(생각), 이유를 포함한 예측과 그렇지 않은 예측을 섞어 학습하고(말하기), REINFORCE를 사용하여 더 나은 이유를 생성하도록 학습합니다(학습). 
  • 우리는 Quiet-STaR을 Mistral 7B (Jiang et al., 2023)에 적용하고, 웹 텍스트 데이터셋인 OpenWebMath (Paster et al., 2023)와 Colossal Clean Crawled Corpus (C4, Raffel et al., 2020)를 사용했습니다. 
  • 데이터셋별 미세 조정 없이도 Quiet-STaR은 CommonsenseQA에서 제로샷 추론 능력이 36.3%에서 47.2%로, GSM8K에서 5.9%에서 10.9%로 개선되었으며, 이러한 개선은 모델의 내부 사고에 사용된 토큰 수가 증가함에 따라 지속적으로 증가했습니다. 
  • 마지막으로, 우리는 생성된 이유에서 나타나는 패턴을 질적으로 조사했습니다.
  • 이 작업을 해결하면서 우리는 다음과 같은 기여를 했습니다:
    • 1. STaR을 일반화하여 다양한 비구조적 텍스트 데이터로부터 추론을 학습하도록 했습니다. 우리가 아는 한, 이는 정교하게 큐레이팅된 추론 작업이나 특정 추론 작업 모음이 아닌, 일반적인 텍스트에서 추론을 학습하도록 언어 모델(LM)을 훈련시키는 첫 번째 연구입니다.
    • 2. 주어진 문자열의 모든 토큰 위치에서 이유(rationale)를 생성할 수 있도록 병렬 샘플링 알고리즘을 제안하고 구현하여 우리의 훈련 절차가 확장 가능하도록 만들었습니다.
    • 3. 각 생각(thought)의 시작과 끝에 맞춤형 메타 토큰을 도입하여 LM이 언제 이유를 생성해야 하고 그 이유를 기반으로 예측해야 하는지 학습할 수 있도록 했습니다.
    • 4. 현재의 다음 토큰 예측에 주어진 생각에서 나온 다음 토큰 예측을 얼마나 반영할지 회고적으로 결정하는 mixing head를 도입했습니다.
    • 5. 다중 토큰을 고려하는 비근시적 손실(non-myopic loss)이 사고(thinking)의 효과를 향상시킨다는 것을 보여주었습니다.
    • 6. 여러 작업에서 사고(thinking)를 통해 LM이 동일한 웹 텍스트로 학습된 모델보다 어려운 토큰을 더 잘 예측할 수 있으며, 사고 시간이 길어질수록 성능이 개선된다는 것을 증명했습니다.

2 Related Work

2.1 Reasoning in Language Models

  • 언어 모델을 훈련하여 어려운 과제를 해결하는 방법 중 먼저 그 과제를 추론하도록 훈련시키는 연구들이 많이 진행되었습니다. 
  • 예를 들어, Rajani et al. (2019)은 다중 선택 상식 추론 질문에 답하기 전에 인간의 추론 과정을 출력하도록 미세 조정된 사전 훈련 언어 모델이, 정답만 학습한 모델보다 뛰어난 성능을 보였음을 증명했습니다. 
  • Shwartz et al. (2020)은 언어 모델이 약간의 지침만 제공되었을 때 추가적인 감독 없이도 유용한 "사고의 연쇄(chain-of-thought)" 솔루션을 생성할 수 있음을 입증했습니다. 
  • 이후 Nye et al. (2021)은 언어 모델의 성능이 향상됨에 따라 "스크래치패드(scratchpads)"가 덜한 지침으로도 충분함을 보여주었으며, 이는 Wei et al. (2022b)에 의해 비공식적인 작업에서 강화되었고, Kojima et al. (2022)는 이러한 행동이 제로샷으로 가능하다는 것을 입증하여 더욱 강화되었습니다. 
  • 가장 최근에는 Wang & Zhou (2024)가 상식 질문 답변에서 언어 모델이 유효한 정답 토큰을 내뱉지 않도록 강제하여 "사고의 연쇄" 추론을 활용하게 할 수 있음을 보여주었습니다. 
  • 그러나 다시 한 번 이러한 접근 방식은 질문-답변 데이터셋에서만 작동하며, Wang & Zhou (2024)는 모델이 정답 토큰을 출력했는지 여부를 식별하는 데 휴리스틱에 의존했습니다. 
  • TRICE (Phan et al., 2023)와 유사하게, 우리는 이유(rationale)를 통해 목표 텍스트의 로그 가능도(log-likelihood)에서 상대적인 개선을 품질의 추정치로 사용하지만, 우리는 단순히 평균 보상을 빼고 복잡한 통제 변수는 사용하지 않습니다.

2.2 Training Language Models to Reason

  • 연구자들이 언어 모델(LM)의 추론 능력을 향상시키기 위해 사용하는 한 가지 방향은, 채굴된 추론 과정이나 추론과 유사한 데이터를 훈련에 활용하는 것입니다(Rajani et al., 2019; Wei et al., 2021a; Lewkowycz et al., 2022; Chung et al., 2022; Gunasekar et al., 2023). 
  • 이러한 접근 방식은 효과가 있음을 입증했지만 몇 가지 단점이 있습니다. 
    • 이는 수작업 주석을 필요로 하는데, 이 과정은 주석자의 능력에 민감하고 언어 모델에 적합하지 않은 오프-폴리시(off-policy) 방식입니다(즉, 언어 모델이 생성할 가능성이 있는 텍스트의 분포와 다른 추론을 사용). 
    • 또한, 이 접근 방식은 비용이 많이 들고 확장하기 어렵습니다. 
    • 주석자가 해결할 수 있는 문제보다 더 어려운 문제를 해결할 수 있는 명확한 방법도 제공하지 않습니다.
  • 추론을 가르치는 또 다른 방향은 언어 모델이 스스로 생성한 추론을 활용하는 것입니다. 
  • 이는 셀프 플레이(self-play)에 관한 방대한 문헌을 바탕으로 한 방식으로 볼 수 있습니다(Silver et al., 2017; Anthony et al., 2017; Polu & Sutskever, 2020). 
  • 이러한 방법 중 하나로는 Self-Taught Reasoner(STaR, Zelikman et al., 2022)가 있습니다. 
    • STaR은 올바른 답을 이끌어낸 모델의 추론을 반복적으로 학습시켜 점점 더 어려운 문제를 해결할 수 있음을 보여주었습니다. 
    • 이후 연구는 추가 정보나 가정을 활용하려고 했습니다. 
    • 예를 들어 Huang et al. (2022)은 STaR 알고리즘이 다수결 답변이 올바르다는 가정 하에서도 여전히 작동할 수 있음을 보여주었지만, 이는 궁극적인 성능이 낮아지는 경향이 있었습니다. 
  • 추가 연구에서는 Zelikman et al. (2022)의 결과를 일반화하였습니다.
    • 예를 들어, Uesato et al. (2022)은 잘못된 추론 과정을 필터링하는 "과정 기반" 감독이 유용하다는 것을 보여주었고, VSTaR(Hosseini et al., 2024)은 생성 과정을 안내하기 위해 검증자를 훈련시키는 것이 성능을 향상시킨다는 것을 증명했습니다. 
    • 또한, TRICE(Hoffman et al., 2024)는 문제당 여러 개의 추론 과정을 통해 올바른 답변의 주변 가능성(marginal likelihood)을 극대화하는 방식으로 성능을 높였습니다.
    • 마지막으로, 관련 연구는 수학적 진술을 만드는 제한된 환경에서 중간 추론 학습을 탐구하기도 했습니다. 
    • 여기서 모델의 중간 추론에서 나오는 진술은 유효한 수학적 진술로만 제한되었습니다(Poesia et al., 2023). 
  • 관련된 추론 작업에 대한 추가 논의는 부록 F에서 다룹니다.

2.3 Meta-tokens

  • 최근 연구들은 신경망의 맥락에서 특정 기능을 수행하도록 최적화된 맞춤형 토큰(custom tokens)의 유용성을 점점 더 많이 입증하고 있습니다. 
  • 이러한 이유로 이들은 "기능 벡터(function vectors)"라고도 불립니다(Todd et al., 2023). 
  • 이 개념의 초기 구현 중 하나는 프롬프트 튜닝(prompt-tuning)이었으며(Lester et al., 2021), 유사하게 프리픽스 튜닝(prefix-tuning, Li & Liang, 2021)도 있었습니다. 
    • 여기서는 프롬프트의 토큰에 해당하는 임베딩을 최적화하여 작업을 더 잘 수행할 수 있도록 했습니다. 
  • 다른 연구에서는 긴 프롬프트를 효율적으로 압축하기 위해 메타 토큰을 적용했습니다(Li et al., 2023; Jung & Kim, 2023). 
  • 이 연구와 가장 관련이 깊은 연구로는 Mu et al. (2024)이 있는데, 이들은 뒤의 토큰들이 앞의 토큰에 집중(attend)할 수 없을 때(즉, 맥락 압축 토큰(context compression token)) 충분한 정보를 제공하도록 토큰을 최적화했습니다. 
  • 비록 우리는 압축에 중점을 두지는 않지만, 주의(attention)에 영향을 미치고 복잡한 다운스트림 행동을 제어하는 토큰을 학습하는 문제를 공유합니다.
  • 관련된 또 다른 연구에서는 Goyal et al. (2023)이 "pause" 토큰이라는 단일 토큰을 학습하여 LM 성능을 향상시킬 수 있음을 보여주었습니다. 
    • 이 pause 토큰은 본질적으로 각 토큰을 두 개의 토큰으로 나타내는 것과 비슷합니다. 
    • 그러나 우리의 작업에서 사고(thought)를 시작하는 생각 토큰(thought tokens)과 달리, 이 pause 토큰은 사고를 시작하지 않으며, 사고 전체를 나타내는 것으로 볼 수 있습니다. 
  • 우리는 언어에서의 추론이 훨씬 더 유용하다는 것을 발견했습니다.
  • special tokens과 같은것들에 대한 연구 같음.
  • prompt tuning할때도 prefix token이라해서 단어적으로는 의미없는 토큰을 넣어서 프롬프트가 되도록 학습하는 기법도 있었던걸로 기억
    • https://velog.io/@mmodestaa/GPT-3-%EB%93%B1%EC%9E%A5%EA%B3%BC-%EA%B7%B8-%ED%9B%84-Prompting-and-Promt-Tuning-Prefix-Tuning-P-Tuning
  • contet을 압축하는 토큰을 붙이는 연구도 있었나봄 (신기하네)
  • pause 토큰을 사용하는 연구도 있다고함
    • Training language models with pause tokens 대충 그림만보면 오호? 하는 생각이 들긴함
    • 하지만 pause에서 사고 전체를 나타내는것이라 이 연구와 조금 다르다고함

3 Problem Statement

  • 이 연구에서는 시퀀스의 관찰된 각 토큰 쌍 사이에 보조 '이유(rationale)' 변수를 도입합니다. 
  • 그 후, 우리는 언어 모델의 매개변수 θ를 최적화하여 중간 사고(또는 이유)를 생성할 수 있도록 하는 것을 목표로 합니다. 
  • 이는 다음과 같이 표현됩니다:
    • 즉, i번째까지의 token과 이를 기반으로 생성된 rationale을 통해 i번째 이후의 token들을 생성하라는것
  • 원칙적으로, 이것은 이미 언어의 문자열 분포를 정확하게 모델링하는 최적의 언어 모델에 비해 이점을 제공하지 않습니다. 
  • 그러나 실제로는 이전의 많은 연구에서 언어 모델이 추론 작업에서 중간 이유를 통해 이익을 얻는다고 보여주었습니다(Nye et al., 2021; Zelikman et al., 2022; Wei et al., 2022b). 
  • 일부 연구에서는 연쇄 추론(chain-of-thought reasoning)의 효과를 설명하려고 시도했으며, 이를 "경험의 국소성(locality of experience)"에 기인한다고 설명합니다(Prystawski et al., 2024). 
  • 더 넓게 보면, 추론은 모델이 복잡한 계산을 더 작은 단계로 분해할 수 있게 해줍니다. 
  • 실질적으로, 우리는 모델이 미래 텍스트를 예측하는 데 효과적인 분해 및 계획 단계를 학습하도록 훈련합니다. 
  • 또한, 우리는 목표를 단순히 다음 토큰만 예측하는 것이 아니라 남은 시퀀스를 정확하게 예측하는 것으로 공식화합니다. 
    • 일반적인 LM objective랑 뭐가 다른가?
    • 최적의 언어 모델에서는 이 두 가지가 동일하겠지만, 우리는 비단기적(non-myopic) 공식화가 이유 학습에 더 효과적인 손실로 이어진다는 것을 발견했습니다.

4 Quiet-STaR

  • Quiet-STaR는 세 가지 주요 단계로 작동합니다 (그림 1):
    • think을 토큰별로하는데, 하나의 생각만 하는게 아니라 여러개 해서 샘플링한다.
  • Parallel rationale generation (think, Subsection 4.2)
    • 입력 시퀀스 \(x_{0:n}\)의 \(n\)개의 토큰 \(x_i\)에 대해 병렬로 길이 \(t\)의 \(r\)개의 추론 \(c_i = (c_{i1}, ..., c_{it})\)을 생성하여, 총 \(n \times r\)개의 추론 후보를 만듭니다. 
    • 각 추론의 시작과 끝을 표시하기 위해 학습된 <|startofthought|> 및 <|endofthought|> 토큰을 삽입합니다.
  • Mixing post-rationale and base predictions (talk, Subsection 4.3): 
    • 각 추론 후의 은닉 상태 출력을 기반으로, 'mixing head'를 훈련시킵니다. 
    • 이는 얕은 다층 퍼셉트론(MLP)으로, 다음 토큰 예측 로짓에서 추론 후 로짓을 기본 언어 모델 로짓과 비교하여 얼마나 반영할지 결정하는 가중치를 생성합니다. 
    • 이 접근법은 추론이 도입됨에 따른 분포 변화로 인해 파인튜닝 초기에 발생할 수 있는 어려움을 완화합니다.
  • Optimizing rationale generation (learn, Subsection 4.4): 
    • 미래 텍스트의 확률을 높이는 추론을 생성할 수 있도록 추론 생성 매개변수(시작/끝 토큰 및 LM 가중치)를 최적화합니다. 
    • 우리는 REINFORCE를 사용하여 미래 토큰 예측에 미치는 영향을 기반으로 추론에 학습 신호를 제공합니다. 
    • 분산을 줄이기 위해, 생각 후의 토큰뿐만 아니라 이후의 토큰 예측 가능성을 손실에 포함시키는 teacher-forcing trick를 적용합니다.

4.2 Parallel Generation

  • Quiet-STaR에서 중요한 도전 과제 중 하나는 입력 시퀀스의 각 토큰 위치에서 효율적으로 추론을 생성하는 것입니다. 
    • 이를 단순하게 처리하면 각 토큰에 대해 별도의 순방향 패스를 수행해야 하며, 이는 긴 시퀀스의 경우 계산적으로 감당할 수 없게 됩니다.
  • 우리는 먼저 언어 모델의 추론 패스가 모든 입력 토큰에 대해 다음 토큰의 확률 분포를 생성한다는 점을 관찰하여 매우 병렬적인 생성을 허용합니다. 
    • 이를 통해 입력의 각 토큰에서 하나의 다음 토큰을 샘플링할 수 있습니다. 
  • 각 토큰의 후속 토큰을 생성하면 원래 시퀀스를 단순히 계속할 수는 없습니다. 
    • 예를 들어, “< bos > the cat sat”라는 문장에서 각 토큰 다음의 토큰을 예측한다고 하면, "yes orange saw down"과 같은 결과가 나올 수 있습니다. 
      • "The" -> "dog"
      • "cat" -> "ran"
      • "sat" -> "quickly"
      • "on" -> "under"
      • "the" -> "big"
      • "mat" -> "door"이때, 각 토큰 다음에 올 수 있는 "가능한" 후속 토큰이 예측되지만, 이들은 모두 다른 문맥에서 나온 것입니다. 즉, "The dog", "cat ran", "sat quickly" 등은 각 접두사(앞 토큰들)에 대한 합리적인 후속이지만, 이 결과들을 조합한 "dog ran quickly under big door"는 원래 문장의 자연스러운 흐름과는 무관합니다. 이처럼 각 예측은 원래 문장의 "반사실적" 연속이 됩니다.
    • 각각의 후속 토큰은 해당 시퀀스의 접두사에 대해 합리적인 다음 토큰이 될 수 있지만, 이 토큰들의 목록은 이러한 접두사의 "반사실적(counterfactual)" 연속을 나타내는 세트입니다. 
    • 그러나 우리는 이러한 연속을 활용하여 관찰된 각 토큰에 대한 숨겨진 생각을 생성할 수 있습니다.
  • 이를 효율적으로 수행하기 위해 각 순방향 패스를 캐시하고, 이전 주의(attention) 마스크에 대각선 주의 마스크를 연결합니다. 
  • 이제 생성된 각 토큰은 그것을 생성하는 데 사용된 모든 토큰과 자신에게 주의를 기울입니다(단, 다른 "반사실적" 경로에 있는 토큰에는 주의를 기울이지 않습니다). 
  • 또한, 이 병렬화된 다음 토큰 샘플링 절차는 메모리가 허용하는 한 임의의 횟수만큼 반복할 수 있습니다. 
  • 이 절차는 그림 3에 시각화되어 있으며, 추가적으로 이 알고리즘을 더 빠르게 만드는 방법은 부록 B에서 강조됩니다.
    • think? 단계라하는 여기서, 각 토큰에서 생각하는 path을 생성하는데 이전 토큰들이 입력으로 들어감
    • 즉 d'을 생성할때는 a,b,c,d가 입력임. d'와 같이 think할때는 special token이 들어가서 생각이라고 LM에게 알려주는 듯
    • next token인 e을 예측할때, a,b,c,d,a',b',c',d'가 들어가는 듯
    • 이런 방법은 근데 엄청 메모리가 많이 필요할거 같은데..??

4.3 “Mixing” (Residual) Heads

  • 사전 훈련된 모델을 시작할 때, 생각(thoughts)은 초기에는 분포에서 벗어나기 때문에 언어 모델링 성능에 해를 끼칠 수 있습니다. 
  • 이러한 "생각"으로의 전환을 부드럽게 하기 위해, 우리는 생각을 포함한 예측과 포함하지 않은 예측 사이에서 학습된 보간(interpolation)을 도입합니다.
  • 구체적으로는, end-of-thought token을 나타내는 토큰의 은닉 상태와 원래 텍스트 토큰의 은닉 상태를 바탕으로, 혼합 헤드(mixing head)가  post-thought prediction logits이 얼마나 사용될지를 결정하는 가중치를 출력합니다. 
  • 이 혼합 헤드는 얕은 다층 퍼셉트론(MLP)으로 구성되며, 각 토큰에 대해 스칼라 값을 출력합니다. 구현 세부 사항은 부록 A에 포함되어 있습니다.

4.4 Optimizing Rationale Generation 

4.4.1 Optimizing Start-of-Thought and End-of-Thought Tokens

  • <|startofthought|>와 <|endofthought|> 토큰은 모델의 추론(rationale) 생성을 제어하는 학습된 메타 토큰 역할을 합니다. 
  • 특히 <|startofthought|> 토큰의 표현을 최적화하는 것이 매우 중요하지만, 추론 토큰들이 이산적인(discrete) 성질을 가지기 때문에 도전적입니다.
  • 이 토큰들의 임베딩은 일반적으로 텍스트 데이터에서 생각이나 중단을 나타내는 "em 대시 ”−−−”에 해당하는 임베딩으로 초기화됩니다. 
  • 이를 통해 언어 모델이 이미 가지고 있는 지식을 활용할 수 있습니다. 
  • 또한, 이러한 임베딩이 더 빨리 최적화될 수 있도록, 업데이트 단계에서 해당 임베딩의 그래디언트에 하이퍼파라미터 가중치를 적용합니다.
  • 직관적으로, <|startofthought|> 토큰은 모델을 "생각 모드"로 전환시키는 역할을 하며, <|endofthought|> 토큰은 모델에게 "생각이 끝났다"고 알리는 역할을 한다고 이해할 수 있습니다.
  • 즉 special token으로 think라는 것의 시작과 끝을 LM에게 알려줘서 단순 next token과 차별을 둔다

4.4.2 Non-myopic Scoring and Teacher-forcing 

  • 우리는 생각(thought)이 모든 토큰을 예측하는 데 유용하지 않을 것이라고 예상하기 때문에, 모델의 보상이 생각 이후의 정확한 다음 단어보다는 이후의 의미적 내용에 더 의존하길 원합니다. 
  • 여기에는 두 가지 주요 도전 과제가 있습니다.
  • 첫째, 일반적인 트랜스포머 기반 언어 모델링과 달리, 병렬 샘플링 전략의 결과로 특정 다음 토큰 예측에 해당하는 생각들만 그 예측으로부터 그래디언트를 받습니다. 
    • 이 문제를 해결하기 위해, 이전 토큰들을 샘플링하여 미래 토큰에 대한 손실 항을 추가할 수 있습니다. 
    • 보통 바로 다음 토큰을 생성하도록 loss을 계산하는데 이는 step=1인 개념. 여기서는 step=ntrue(알고리즘상의 변수)만큼 생성하도록 loss을 계산한다는 것
    • 그러나 이는 일반적으로 언어 모델링의 엔트로피를 크게 증가시키고, 모델이 이전 토큰들을 부분적으로 무시하도록 훈련되기 때문에, 생성된 텍스트의 품질이 낮아지는 결과를 초래할 수 있습니다.
      • ### 예시:
      • #### 1. **기본 언어 모델 예측** (엔트로피 낮음):
      • 먼저, 언어 모델이 일반적인 방식으로 다음 단어를 예측하는 상황을 가정해보겠습니다.
      • **입력 문장**:
      • "The cat sat on the"
      • **기본 언어 모델 예측**:
      • 이 경우, 모델은 "mat"과 같은 단어를 예측할 가능성이 큽니다. 왜냐하면, 앞의 문맥에서 "The cat sat on the"라는 구문은 매우 일반적이고, 자연스럽게 이어질 단어가 "mat"일 가능성이 높기 때문입니다.
      • - **가능한 선택지**: ["mat", "floor", "ground"]
      • - **엔트로피**: 낮음 (대부분의 확률이 "mat"에 집중됨).
      • #### 2. **미래 토큰에 대한 손실 항 추가** (엔트로피 증가):
      • 이제, 미래 토큰에 대한 손실 항을 추가하는 상황을 생각해봅시다. 이 방식에서는 다음 몇 개의 토큰에 대한 예측을 한 번에 고려하게 됩니다.
      • **입력 문장**:
      • "The cat sat on the"
      • 이때, 모델은 더 먼 미래까지 고려하려고 합니다. 예를 들어, "The cat sat on the mat"이 아니라, 다음 문장까지 고려해서 예측할 수 있습니다.
      • 모델이 미래의 토큰도 예측하려고 하면, "The cat sat on the mat" 다음에 나올 수 있는 다양한 가능성(예: "and started to sleep", "but then ran away", "while the dog barked") 등을 모두 고려해야 합니다.
      • 이 경우, 이전 문맥인 "The cat sat on the"가 아닌 미래의 문장까지 고려하려고 하기 때문에, 다음에 나올 단어가 **다양해질 수 있습니다**. 예를 들어:
      • - **가능한 선택지**: ["mat", "dog", "grass", "rock", "lake"]
      • (다양한 미래 상황을 예측하게 되면서 선택지가 더 많아짐)
      • 이렇게 되면, 다음 단어를 예측할 때 더 많은 가능성을 고려해야 하므로, 결과적으로 **엔트로피가 증가**합니다. 모델이 "mat"을 예측할 가능성이 여전히 존재하지만, 이제는 "dog"나 "lake" 등 다양한 가능성도 추가적으로 고려하게 되면서 불확실성이 증가합니다.
      • #### 3. **결과적으로 엔트로피가 증가하는 이유**:
      • 미래의 토큰까지 손실 항에 포함시키면, 모델이 단순히 다음 토큰만 예측하는 것이 아니라, 더 먼 미래까지 염두에 두고 예측을 시도합니다. 이로 인해 각 토큰에 대해 더 많은 가능성을 고려하게 되어, 모델의 예측이 더 불확실해지고 엔트로피가 증가하게 됩니다.
      • 요약하자면, **미래 토큰에 대한 손실 항 추가**는 더 많은 가능성을 모델이 고려하게 만들고, 이로 인해 언어 모델이 다음 토큰을 예측하는 데 있어 불확실성이 증가하여 엔트로피가 높아지게 됩니다.
  • 대신에, 우리는 병렬 주의 마스크(parallel attention mask)를 사용하여 실제 다음 토큰의 로그 확률을 계산합니다. 
    • 여기서 교사 강제(teacher forcing)를 적용하여 모델이 정확한 정답 토큰을 선택했다고 가정합니다(이는 트랜스포머 기반 언어 모델링에서 암묵적으로 사용되는 방식입니다). 
    • 또한, 각 미래 토큰의 손실은 생각 끝 토큰과 이전에 관찰된 토큰으로부터 계산된 혼합 가중치에 의존합니다. 
  • 손실에 포함되는 미래 토큰의 수는 하이퍼파라미터입니다.
  • 우리는 시작 토큰과 끝 토큰을 삽입할 때도 동일한 교사 강제 기법을 적용합니다. 
  • 이 절차는 그림 4에서 시각화됩니다.
    • 우리는 알고리즘의 단일 순방향 패스를 시각화합니다. 실선은 언어 모델의 계산을 나타내며, 점선은 교사 강제를 통해 토큰이 삽입되는 것을 나타냅니다. 혼합기는 **혼합 헤드(mixing head)**를 의미합니다. 특히, 우리는 세 개의 토큰을 앞서 예측하는 과정을 시각화합니다.
    • 생각(thought) 생성 과정은 그림 1과 그림 3에서 더 자세히 설명됩니다.

4.4.3 Objective 

  • REINFORCE를 사용하여 유용성에 기반한 근거(rationales)의 가능성을 최적화합니다. 
  • 이는 이전에 관찰된 토큰과 특정 근거를 고려했을 때, ntrue개의 다음 진짜 토큰 Xj+1:j+ntrue+1에 대한 언어 모델의 로그 가능도(log-likelihood)를 의미합니다. 
    • 알고리즘 1에서 설명된 대로, 생각한 후의 mixed prediction probabilities을 p_talk_j:j+ntrue로 약칭합니다. 
    • 여기서 ntrue는 예측할 미래 토큰개수. 헷갈리면 그냥 fixed 숫자라고 생각해도 됨
    • l개 토큰이 주어진다고 하자
    • 생각없이
      • l개의 토큰에 대한 hidden state h^init을 계산하고, j:j+ntrue에 해당하는 next token을 생성하도록 하는 부분이 logp^init_j:j+ntrue
    • 생각있이
      • 각 토큰에서 Tj라 하는 생각 토큰들을 생성한다. (j=1~l)
      • j까지의 토큰; Tj; j이후의 ntrue토큰을 concat하여 hiddens tate h^thought을 생성한다. (즉 기존 시퀀스 사이에 Tj을 끼는 것과 같다)
      • hidden state h^thought_j:j+ntrue에서부터 next token(j+1번째 토큰일듯)을 생성하는게 logp^thought_j:j+ntrue
    • 생각없이와 생각있이(Tj)을 혼합하도록 weight을 MLP로 구하는거 같음
    • logp^init_j:j+ntrue와 logp^thought_j:j+ntrue의 가중치합으로 mixed 예측하고 이를 p^talk_j라고 하고 이를 이용한 loss가 NLL loss임
    • 이거를 할 떄, think에 해당하는 Tj는 한개의 시퀀스가 아니라 여러 개의 시퀀스임
    • 이 시퀀스에 해당하는 logp^talk의 평균을 각 think 시퀀스에서 빼준것을 rj로 정의 (REINFORCMENT에서 분산을 줄여주는 효과)
    • 이를 통해 입력된 시퀀스 X:j에서 j번째 think Tj을 생성하는 REINFORCMENT loss을 정의함. 
      • 개념으로 생각하면 next token이 잘 생성되도록 think 시퀀스가 학습되는 식인듯
    • NLL과 REINFORCEMENT loss을 더해서 학습
  • 분산을 줄이기 위해 입력 시퀀스의 각 토큰에 대해 여러 개의 근거 확장을 생성합니다(Phan 등(2023)의 TRICE에서 느슨하게 영감을 받았습니다). 
  • 따라서 각 근거 Tj에 대한 보상 rj는 해당 토큰에 대한 모든 근거의 평균(p_talk_j:j+ntrue)과의 차이로 정의됩니다:
  • 이 보상은 REINFORCE 손실 항에 사용되어, 평균보다 더 나은 성능을 내는 근거의 가능성을 높이도록 언어 모델의 매개변수 θ를 업데이트합니다:
  • 우리는 음의 보상을 REINFORCE 손실 항에서 제외하는 것이 더 안정적인 학습을 유도한다는 것을 발견했지만, 이는 약간의 편향을 도입할 수 있습니다. 
    • 오히려 평균을 안빼는게 안정적인 학습이라는 듯. 하지만 편향을 막기 위해서 넣은듯
  • 이 손실 항은 모델이 해당 토큰에 대해 생성된 모든 근거의 평균 예측과 비교하여 미래 토큰의 예측을 개선하는 근거를 생성하도록 유도합니다. 
  • 이 손실로부터 얻은 그래디언트는 LM(언어 모델) 매개변수와 start-of-thought 및 end-of-thought 토큰 임베딩을 업데이트하는 데 사용됩니다. 
  • 또한, start-of-thought와 end-of-thought 토큰 임베딩의 최적화를 가속화하기 위해 이들에 대한 그래디언트에 (하이퍼파라미터) 가중치가 적용됩니다. 
  • 이러한 매개변수를 반복적으로 최적화함으로써 Quiet-STaR는 훈련 동안 더 유용한 근거를 생성하는 모델을 학습시킵니다. 
  • 마지막으로, LM이 talking heads를 최적화하고 기본 LM head에 대해 다음 토큰 예측 신호를 받을 수 있도록 로그 가능도 손실(L_NLL_j)도 포함합니다.

5 Experiments and Results

  • 직관적으로, 모든 토큰이 동일한 양의 사고(thought)를 필요로 하는 것은 아닙니다. 예를 들어, “the person is run-”라는 문장을 고려해보면, "ing" 이외의 다른 토큰이 올 확률이 존재하지만, 이 문장이 독립적으로 주어진 경우, 추가적인 사고가 잘 훈련된 모델의 예측 성능을 개선할 가능성은 낮습니다. 실제로, 대부분의 온라인 텍스트의 대부분의 덩어리에 대해 추가적인 사고는 거의 또는 전혀 영향을 미치지 않는다고 추측합니다. 초기 탐색 과정에서 Quiet-STaR가 모든 토큰에 동일하게 이익을 주지 않는다는 것을 관찰했습니다. 따라서, 우리는 우리의 접근 방식이 사고를 필요로 하는 토큰의 예측에 유용한지를 조사하기 위해 실험을 설계했습니다. 우리는 1) Quiet-STaR가 추론을 필요로 하는 데이터셋에서 언어 모델의 직접적인 예측 능력을 향상시키는지, 그리고 2) 사고 토큰으로 인한 영향을 분포를 평가했습니다. 모든 실험은 Mistral 7B의 기본 버전을 시작으로 진행했습니다(Jiang et al., 2023).
  • 대부분의 실험은 OpenWebMath(Paster et al., 2023)에서 훈련하여 진행했습니다. OpenWebMath는 기술적인 웹페이지를 강조하는 크롤로, 우리는 이 데이터셋이 추론에서 이익을 볼 가능성이 높은 토큰의 밀도가 높을 것으로 예상했습니다. 실험 결과도 이를 뒷받침했습니다. 또한, 다양한 텍스트가 포함된 널리 사용되는 LM 프리트레이닝 코퍼스인 C4(Raffel et al., 2020)에서도 Quiet-STaR를 평가했으며, 여기서도 상당한 이익을 보여주었지만, 규모는 좀 더 작았습니다.

5.1 Downstream Performance 

5.2 Improvement Distribution

5.3 Quiet-STaR and Chain-of-Thought

5.4 Examples

  • Quiet-STaR에서는 사고(thought)가 인간이 해석할 수 있도록 명시적인 정규화가 존재하지 않지만, 사고는 언어 모델을 훈련하는 동일한 트랜스포머에서 생성되기 때문에 적어도 부분적으로 이해할 수 있을 가능성이 높습니다. 이러한 설계 선택이 훈련 안정성에 이점이 되는 이유는 부록 I에서 논의합니다. 참고로, OpenWebMath에서 미래의 토큰 예측에 도움이 되었던 생성된 사고의 예를 포함합니다.
  • 예를 봐도 뭘 말하고 싶은지 느낌이 와닿지는 않는듯
  • 첫 번째 사례에서는 마그네슘으로 시작하여 마그네슘 질화물을 생성해야 한다는 점을 상기시킴으로써, 절차의 첫 단계가 마그네슘을 가열하는 것이라는 예측을 더 잘 할 수 있도록 합니다. 
    • <s> # 마그네슘은 질소와 반응하여 마그네슘 질화물을 형성합니다.
    • 이 반응의 화학식은 Mg + N₂ → MgN₂입니다.
    • 이 반응의 생성물은 무엇인가요?
    • 2016년 1월 12일
    • 마그네슘 질화물의 화학식은 $Mg_{3}N_{2}$입니다.
    • #### 설명:
    • 많은 활성 금속처럼, 마그네슘 질화물은 <|startofthought|> 1 --, 따라서 마그네슘 질화물을 형성하는 반응의 화학식은 $Mg + N₂ \to <|endofthought|> 금속을 가열하여 형성됩니다 (fier' 
  • 일부 경우, 가장 유용한 사고는 목표 텍스트와 더 밀접하게 일치하는 근접 연속으로 나타납니다.
    • 정수 \( n \)이 홀수라는 것은 \( n = 2k + 1 \)의 형태로 표현될 수 있는 경우를 말합니다. (여기서 \( k \)는 정수)
    • \( A = B \)임을 증명하기 위해서는 두 가지를 보여야 합니다:
    • 1. \( A \subseteq B \)
    • 2. \( B \subseteq A \)
    • 이 중 첫 번째 조건은 학생들에게 <|startofthought|> 어떤 면에서 - 더 어려운<|endofthought|> 경우가 많습니다.
  • 마지막으로, CommonsenseQA를 답변하는 예를 포함합니다. 주목할 만한 점은 이 사고가 질문을 읽는 동안 발생했으며, 따라서 최종 답변을 예측하는 데 사용되지 않았다는 것입니다.
    • **질문:**
    • 같은 사람과 같은 주제에 대해 계속해서 이야기하는 것은 무엇인가요?
    • <|startofthought|> (a) 일대일 상관관계 (b) 일대일 <|endofthought|> 어떤 것을 할 수 있는 것인가요?

6 Limitations 

  • 이 연구는 추론 학습을 위한 새로운 프레임워크를 제안하며, 다양한 메타 학습 문제에 대한 해결책을 탐구합니다. 그러나 이러한 문제를 해결하기 위해 몇 가지 단순화가 필요했습니다. 예를 들어, 모델이 처음부터 학습할 때 이러한 기술이 작동하는지를 이해하는 것이 중요합니다. 또한 우리는 Quiet-STaR를 70억 개의 매개변수를 가진 강력한 모델에만 적용했습니다. 더 나은 모델에 동일한 기술을 적용하면 비례적으로 더 좋은 결과를 얻을 가능성이 높습니다. 이는 추론을 통한 성과 향상이 자주 관찰되는 것과 일치합니다(Wei et al., 2022a).
  • Quiet-STaR는 추가 토큰을 생성하기 전에 많은 토큰을 생성하는 상당한 오버헤드를 발생시킵니다. (성능 결과는 Appendix C를 참조하십시오.) 그러나 이는 장점으로도 볼 수 있습니다. 일반적으로 언어 모델은 현재 맥락에 따라 다음 토큰을 생성할 수 있으며, 샘플링 품질을 개선하기 위한 기술이 있지만, 추가적인 계산을 활용해 다음 토큰 예측을 향상시키는 일반적인 방법은 없습니다. 현재 구현에서는 언제 합리적 사고를 생성하거나 종료할지를 동적으로 예측하는 기능을 지원하지 않지만, 이는 자연스러운 확장이 될 것입니다. 예를 들어, 믹싱 헤드가 사고 전 기본 언어 모델의 예측이었다면, 사고 후가 아니라 사고 전이라면, 통합되지 않을 사고를 생성하지 않도록 임계값을 적용할 수 있습니다. 사고의 유용성을 예측하는 것은 이미 사고를 생성한 후에 하는 것이 더 간단하므로, 이는 더 어려운 작업일 것으로 예상합니다.

7 Conclusion

  • Quiet-STaR는 언어 모델이 일반적이고 확장 가능한 방식으로 추론을 학습할 수 있는 방향으로 나아가는 단계입니다. 
  • 특정 데이터셋에 대해 좁게 특화하기보다는 다양한 웹 텍스트에 내재된 풍부한 추론 작업 스펙트럼에서 학습함으로써, Quiet-STaR는 더 강력하고 적응 가능한 언어 모델을 지향합니다. 
  • 우리의 결과는 이 접근 방식의 가능성을 입증하며, Quiet-STaR는 하위 추론 성능을 향상시키면서 질적으로 의미 있는 합리적 사고를 생성합니다. 
  • 우리는 이것이 여러 잠재적 미래 방향을 열어준다고 믿습니다. 
  • 예를 들어, 미래 토큰의 예측을 더욱 개선하기 위해 사고를 집합적으로 사용하는 것을 목표로 할 수 있습니다. 
  • 또한, 언어 모델이 사고가 유용할 때를 예측할 수 있다면, 예를 들어 믹싱 헤드를 예측 이전에 배치함으로써, 예측된 믹싱 가중치를 사용해 생성 중에 동적으로 계산 자원을 할당할 수 있습니다.
  • 향후 연구는 이러한 통찰을 바탕으로 언어 모델과 인간과 유사한 추론 능력 간의 간극을 좁히는 데 기여할 수 있습니다.

Reference

댓글