NL-345, Mamba: Linear-Time Sequence Modeling with Selective State Spaces, CoLM 2024

Abstract

파운데이션 모델은 현재 딥러닝의 흥미로운 응용 대부분을 구동하고 있으며, 거의 예외 없이 Transformer 아키텍처와 그 핵심인 attention 모듈에 기반한다. 긴 시퀀스에서 Transformer의 계산 비효율성을 해결하기 위해 linear attention, gated convolution 및 recurrent 모델, structured state space model(SSM)과 같은 많은 subquadratic-time 아키텍처가 개발되어 왔지만, 언어와 같은 중요한 모달리티에서는 attention만큼 좋은 성능을 보이지 못했다. 우리는 이러한 모델들의 핵심 약점이 content-based reasoning을 수행하지 못하는 데 있음을 확인하고, 이를 개선하기 위한 몇 가지 방법을 제안한다. 첫째, SSM 파라미터가 입력의 함수가 되도록 하는 단순한 변경만으로도 discrete modality에서의 약점을 해결할 수 있으며, 모델이 현재 토큰에 따라 시퀀스 길이 차원을 따라 정보를 선택적으로 전파하거나 잊을 수 있게 한다. 둘째, 이 변경으로 인해 효율적인 convolution을 사용할 수 없게 되지만, 우리는 recurrent mode에서 hardware-aware parallel algorithm을 설계한다. 우리는 이러한 selective SSM을 attention은 물론 MLP block조차 없는 단순화된 end-to-end neural network architecture인 Mamba에 통합한다. Mamba는 빠른 inference, 즉 Transformer보다 5배 높은 throughput을 보이며, 시퀀스 길이에 대해 선형적으로 scaling된다. 또한 실제 데이터에서 million-length sequence까지 성능이 향상된다. 일반적인 sequence model backbone으로서 Mamba는 언어, 오디오, 유전체학 등 여러 모달리티에서 state-of-the-art 성능을 달성한다. Language modeling에서 Mamba-3B 모델은 동일한 크기의 Transformer를 능가하며, pretraining과 downstream evaluation 모두에서 두 배 크기의 Transformer와 맞먹는 성능을 보인다.

1 Introduction

파운데이션 모델(Foundation Models, FMs)은 대규모 데이터로 사전학습(pretraining)된 뒤 다운스트림 작업에 적응되는 대형 모델로서, 현대 머신러닝의 효과적인 패러다임으로 자리 잡았다. 이러한 FM의 백본(backbone)은 대개 시퀀스 모델(sequence model)이며, 언어, 이미지, 음성, 오디오, 시계열, 유전체(genomics) 등 다양한 도메인의 입력 시퀀스를 처리한다 (Brown et al. 2020; Dosovitskiy et al. 2020; Ismail Fawaz et al. 2019; Oord et al. 2016; Poli et al. 2023; Sutskever, Vinyals, and Quoc V Le 2014). 이 개념 자체는 특정 모델 아키텍처에 종속되지 않지만, 현대의 FM은 거의 대부분 Transformer(Vaswani et al. 2017)와 그 핵심 구성요소인 attention layer(Bahdanau, Cho, and Bengio 2015)에 기반하고 있다.

Self-attention의 효과성은 문맥 윈도우(context window) 내부에서 정보를 조밀하게 전달(route)할 수 있다는 능력에서 비롯되며, 이를 통해 복잡한 데이터를 모델링할 수 있다. 그러나 이러한 특성은 근본적인 단점도 함께 가져온다. 첫째, 유한한 윈도우 밖의 정보를 모델링할 수 없으며, 둘째, 윈도우 길이에 대해 계산량이 이차적으로(quadratically) 증가한다. 이러한 문제를 해결하기 위해 attention의 더 효율적인 변형들이 대거 제안되었으나(Tay, Dehghani, Bahri, et al. 2022), 대부분 attention의 핵심적인 장점을 희생하는 대가를 치렀다. 현재까지 이들 변형 중 어느 것도 다양한 도메인에서 대규모로 attention만큼 효과적이라는 것이 실증적으로 입증되지는 못했다.

최근에는 structured state space sequence models(SSMs) (Gu, Goel, and Ré 2022; Gu, Johnson, Goel, et al. 2021)가 시퀀스 모델링을 위한 유망한 아키텍처 계열로 부상하였다. 이러한 모델은 recurrent neural network(RNN)와 convolutional neural network(CNN)의 결합으로 해석될 수 있으며, 고전적인 state space model(Kalman 1960)에서 영감을 받았다. 이 계열의 모델은 recurrence 또는 convolution 형태로 매우 효율적으로 계산될 수 있으며, 시퀀스 길이에 대해 선형 혹은 준선형(linear or near-linear)으로 확장된다. 또한 특정 데이터 모달리티에서는 장거리 의존성(long-range dependencies)을 모델링할 수 있는 원리적인 메커니즘을 제공하며(Gu, Dao, et al. 2020), Long Range Arena(Tay, Dehghani, Abnar, et al. 2021)와 같은 벤치마크를 지배해 왔다. 다양한 형태의 SSM(Gu, Goel, and Ré 2022; Gu, Gupta, et al. 2022; Gupta, Gu, and Berant 2022; Y. Li et al. 2023; Ma et al. 2023; Orvieto et al. 2023; Smith, Warrington, and Linderman 2023)은 오디오와 비전 같은 연속 신호 데이터 도메인에서 성공적이었다(Goel et al. 2022; Nguyen, Goel, et al. 2022; Saon, Gupta, and Cui 2023). 그러나 텍스트처럼 이산적(discrete)이며 정보 밀도가 높은 데이터에서는 상대적으로 성능이 떨어졌다.

SSM이라는게 원래 존재했군?

본 논문에서는 선택적 상태공간 모델(selective state space models)이라는 새로운 계열을 제안하며, 이를 통해 Transformer 수준의 모델링 성능을 달성하면서도 시퀀스 길이에 대해 선형적으로 확장되는 모델을 구현한다.

선택 메커니즘(Selection Mechanism).
우선, 기존 모델들의 핵심적인 한계를 지적한다. 그것은 입력에 의존하는 방식(input-dependent manner)으로 데이터를 효율적으로 선택(select)할 수 없다는 점이다. 즉, 특정 입력에 집중하거나 무시하는 능력이 부족하다. 우리는 selective copy 및 induction heads와 같은 중요한 합성(synthetic) 태스크에서 얻은 직관을 바탕으로, SSM의 파라미터를 입력의 함수로(parameterizing the SSM parameters based on the input) 만드는 간단한 선택 메커니즘을 설계하였다. 이를 통해 모델은 불필요한 정보를 걸러내고(filter out irrelevant information), 중요한 정보는 무기한 기억할 수 있게 된다.

하드웨어 친화적 알고리즘(Hardware-aware Algorithm).
하지만 이러한 단순한 변경은 계산 측면에서 기술적 문제를 야기한다. 사실 기존 SSM들은 계산 효율성을 위해 반드시 시간 및 입력 불변(time- and input-invariant)이어야 했다. 우리는 convolution 대신 recurrence 기반의 scan 연산을 사용하는 하드웨어 친화적 알고리즘을 통해 이 문제를 해결한다. 또한 GPU 메모리 계층 구조 사이에서 발생하는 IO 접근을 줄이기 위해, 확장된 상태(expanded state)를 실제로 메모리에 물질화(materialize)하지 않는다. 그 결과, 제안한 구현은 이론적으로도(모든 convolution 기반 SSM이 pseudo-linear scaling을 가지는 반면, 우리는 sequence length에 대해 truly linear scaling을 달성) 그리고 실제 현대 하드웨어 상에서도(A100 GPU에서 최대 3배 빠름) 기존 방법보다 더 빠르다.

아키텍처(Architecture).
또한 우리는 기존의 깊은 시퀀스 모델 아키텍처를 단순화하였다. 기존 SSM 아키텍처(Dao, Fu, Saab, et al. 2023)의 설계와 Transformer의 MLP block을 하나의 블록으로 결합함으로써, attention은 물론 MLP block조차 포함하지 않는 단순하고 균일한(homogenous) 아키텍처인 Mamba를 제안한다.

Selective SSM, 그리고 이를 기반으로 한 Mamba 아키텍처는 완전한 recurrent model이며, 일반적인 시퀀스 기반 foundation model의 백본으로 적합한 여러 특성을 가진다.

  1. 높은 품질(High quality): 선택 메커니즘(selectivity)은 언어 및 유전체와 같은 정보 밀도가 높은 모달리티에서 강력한 성능을 제공한다.

  2. 빠른 학습 및 추론(Fast training and inference): 학습 시 계산량과 메모리 사용량이 시퀀스 길이에 대해 선형적으로 증가하며, 추론 시 autoregressive unrolling은 이전 요소들의 캐시가 필요 없으므로 step당 상수 시간(constant time)만 소요된다.

  3. 긴 문맥(Long context): 품질과 효율성이 결합되어 실제 데이터에서 최대 백만 길이(sequence length 1M)까지 성능 향상을 이끌어낸다.

우리는 Mamba가 범용 시퀀스 FM 백본으로서 잠재력을 가진다는 점을 여러 모달리티와 설정에서 실험적으로 검증하였다.

  • 합성 태스크(Synthetics). selective copying과 induction heads 같은 중요한 합성 태스크에서, Mamba는 문제를 쉽게 해결할 뿐 아니라 백만 토큰 이상의 길이까지 무한히 일반화(extrapolate)할 수 있다.

  • 오디오와 유전체(Audio and Genomics). Mamba는 오디오 waveform과 DNA sequence 모델링에서 SaShiMi, Hyena, Transformer 같은 기존 최고 성능 모델을 능가하며, pretraining 품질과 downstream metric 모두에서 우수한 결과를 보였다. 특히 긴 문맥을 사용할수록 최대 백만 길이까지 성능이 향상되었다.

  • 언어 모델링(Language Modeling). Mamba는 Transformer 수준의 성능을 진정으로 달성한 최초의 선형 시간(linear-time) 시퀀스 모델이다. 최대 1B 파라미터 규모의 scaling law 실험에서, Mamba는 LLaMa(Touvron et al. 2023) 기반의 강력한 Transformer 학습 레시피들을 포함한 다양한 baseline을 능가하였다. Mamba language model은 유사한 크기의 Transformer보다 5배 높은 생성 처리량(generation throughput)을 가지며, Mamba-3B는 자기 크기의 두 배에 달하는 Transformer와 동등한 품질을 달성하였다(예: Pythia-3B 대비 commonsense reasoning 평균 4점 향상, 심지어 Pythia-7B를 초과).

모델 코드와 사전학습된 체크포인트는 https://github.com/state-spaces/mamba 에 공개되어 있다.















Reference

댓글