NL-346, On the Parameterization and Initialization of Diagonal State Space Models (S4D), NeurIPS 2026

S4D를 처음부터 이해하기: “상태 업데이트”에서 “Convolution”까지

이번 글은 S4D: On the Parameterization and Initialization of Diagonal State Space Models 논문을 처음 보는 사람을 위한 설명이다.
수식이 많이 나오지만, 목표는 하나다.

S4D는 긴 시퀀스를 잘 처리하기 위해, 입력을 내부 기억에 저장하고, 그 기억을 효율적인 convolution kernel로 바꿔 계산하는 모델이다.

이 글에서는 특히 네가 헷갈려 했던 부분들을 중심으로 설명한다.

  • (x'(t))가 무엇인지

  • 왜 갑자기 (y)가 (u)의 함수처럼 보이는지

  • (K(t)=Ce^{tA}B)가 어디서 나오는지

  • S4D는 기존 S4에서 무엇을 단순화한 것인지

논문은 S4D를 기존 S4보다 훨씬 단순한 diagonal SSM으로 만들면서도, 성능은 거의 유지할 수 있음을 보인다.


1. S4D가 풀고 싶은 문제

우리가 다루고 싶은 데이터는 보통 시퀀스다.

예를 들어:

문장:      [나는, 오늘, 학교에, 갔다]
오디오:    [0.01, 0.03, -0.02, ...]
주가:      [100, 101, 99, 102, ...]
이미지:    픽셀을 한 줄로 펼친 sequence

이런 데이터의 핵심은 앞에서 나온 정보가 뒤에 영향을 줄 수 있다는 점이다.

예를 들어 문장에서:

"철수는 영희에게 책을 빌려주었다. 그는 ..."

여기서 “그”가 누구인지 알기 위해서는 앞부분 기억이 필요하다.

오디오도 마찬가지다. 지금 소리 하나만 보고는 단어를 알 수 없다. 앞뒤의 긴 패턴을 같이 봐야 한다.

그래서 S4D는 이런 일을 하려 한다.

입력 시퀀스 u
      ↓
과거 정보를 잘 기억
      ↓
출력 시퀀스 y 생성

2. State Space Model이란?

S4D는 State Space Model, 줄여서 SSM 계열 모델이다.

SSM은 아주 단순하게 말하면:

입력이 들어오면 내부 기억을 업데이트하고, 그 기억을 바탕으로 출력을 만드는 모델

이다.

논문에서 기본 SSM은 이렇게 나온다.

[
x'(t)=Ax(t)+Bu(t)
]

[
y(t)=Cx(t)
]

이 식을 처음 보면 헷갈릴 수 있다. 하나씩 보자.


3. 변수 하나씩 보기

(u(t)): 입력

(u(t))는 시간 (t)에서 들어오는 입력이다.

예를 들어 오디오라면:

u(0) = 0초 시점의 소리 크기
u(1) = 1초 시점의 소리 크기
...

딥러닝에서는 보통 연속시간보다는 discrete sequence를 쓰지만, 논문에서는 먼저 연속시간 식으로 설명한다.


(x(t)): 내부 기억, 상태

(x(t))는 모델 내부에 저장된 기억이다.

예를 들어 지금까지 본 입력을 요약해서 들고 있는 메모장이라고 생각하면 된다.

x(t) = t 시점까지 입력을 보고 모델 내부에 저장된 정보

(x'(t)): 상태의 변화율

여기가 중요하다.

[
x'(t)
]

다음 상태가 아니다.

(x(t))가 시간에 따라 얼마나 빠르게 변하는지를 나타내는 값이다.

즉:

x(t)  = 현재 기억
x'(t) = 현재 기억이 지금 얼마나 변하고 있는지

예를 들어 온도로 비유하면:

x(t)  = 현재 온도
x'(t) = 온도가 초당 몇 도씩 변하는지

따라서 식

[
x'(t)=Ax(t)+Bu(t)
]

는 이렇게 읽어야 한다.

현재 기억의 변화율 =
기존 기억에 의해 생기는 변화
+
현재 입력 때문에 생기는 변화

(A): 기억이 스스로 변하는 방식

(A)는 기존 기억 (x(t))가 시간이 지나면서 어떻게 변하는지를 정한다.

예를 들어:

어떤 기억은 천천히 사라지고
어떤 기억은 빠르게 사라지고
어떤 기억은 진동하면서 남고
어떤 기억은 오래 유지된다

이 성질을 (A)가 결정한다.

S4D에서 가장 중요한 파라미터가 바로 이 (A)다.


(B): 입력을 기억에 넣는 방식

식에서:

[
Bu(t)
]

는 현재 입력 (u(t))가 내부 상태 (x(t))에 들어가는 부분이다.

즉:

B = 입력을 내부 기억 공간에 어떻게 넣을지 정하는 값

(C): 기억을 출력으로 읽는 방식

출력은:

[
y(t)=Cx(t)
]

이다.

즉 (C)는 내부 기억 (x(t))에서 어떤 정보를 읽어 출력 (y(t))로 만들지를 정한다.

C = 기억을 출력으로 읽는 방법

4. 먼저 순서대로 생각하기: 상태 업데이트 관점

연속시간 식은:

[
x'(t)=Ax(t)+Bu(t)
]

이다.

따라서 (t=0)에서는 이렇게 시작한다.

[
x'(0)=Ax(0)+Bu(0)
]

또는 간단히 쓰면:

[
x'_0=Ax_0+Bu_0
]

이게 정확한 출발점이다.

여기서 (x'_0)는 (x_1)이 아니다.
(x_0)에서의 변화율이다.

아주 작은 시간 간격 (\Delta t) 뒤의 상태는 대략:

[
x_1 \approx x_0+\Delta t \cdot x'_0
]

이다.

그러니까:

[
x_1 \approx x_0+\Delta t(Ax_0+Bu_0)
]

정리하면:

[
x_1 \approx (I+\Delta t A)x_0+\Delta t Bu_0
]

여기서 (I)는 identity matrix다.

즉 연속시간 식에서 바로

[
x_1=Ax_0+Bu_0
]

라고 하면 안 된다.

정확히는:

연속시간:
x'_0 = Ax_0 + Bu_0

작은 시간 뒤:
x_1 ≈ x_0 + Δt x'_0

이다.


5. Discretization: 연속시간을 컴퓨터용 식으로 바꾸기

딥러닝에서는 실제로 보통 이런 연속시간 데이터를 직접 쓰지 않는다.

대부분 입력은 discrete sequence다.

u_0, u_1, u_2, u_3, ...

그래서 연속시간 SSM을 discrete-time SSM으로 바꿔야 한다.

이 과정을 discretization이라고 한다.

연속시간 식:

[
x'(t)=Ax(t)+Bu(t)
]

을 discrete하게 바꾸면 보통 이렇게 쓴다.

[
x_{k+1}=\bar{A}x_k+\bar{B}u_k
]

[
y_k=Cx_k
]

여기서 중요한 점은:

[
\bar{A}, \bar{B}
]

는 원래의 (A,B)와 완전히 같은 것이 아니라, discretization을 거친 값이라는 것이다.

간단한 Euler 근사에서는:

[
\bar{A}=I+\Delta t A
]

[
\bar{B}=\Delta t B
]

정도가 된다.

논문에서는 Euler보다는 bilinear, ZOH 같은 discretization 방법을 다룬다. 하지만 초보자 관점에서는 “연속시간 식을 컴퓨터가 처리할 discrete 식으로 바꾼다” 정도로 이해하면 충분하다.


6. 이제 discrete 식을 순서대로 펼쳐보기

이제부터는 다음 discrete 식을 쓰자.

[
x_{k+1}=\bar{A}x_k+\bar{B}u_k
]

[
y_k=Cx_k
]

초기 상태는 단순하게:

[
x_0=0
]

이라고 하자.


Step 0

입력 (u_0)가 들어온다.

[
x_1=\bar{A}x_0+\bar{B}u_0
]

초기 상태 (x_0=0)이므로:

[
x_1=\bar{B}u_0
]

업데이트 후 출력은:

[
y_1=Cx_1
]

따라서:

[
y_1=C\bar{B}u_0
]


Step 1

다음 입력 (u_1)가 들어온다.

[
x_2=\bar{A}x_1+\bar{B}u_1
]

아까:

[
x_1=\bar{B}u_0
]

였으므로:

[
x_2=\bar{A}\bar{B}u_0+\bar{B}u_1
]

출력은:

[
y_2=Cx_2
]

이므로:

[
y_2=C\bar{A}\bar{B}u_0+C\bar{B}u_1
]


Step 2

다음 입력 (u_2)가 들어온다.

[
x_3=\bar{A}x_2+\bar{B}u_2
]

방금 (x_2)를 넣으면:

[
x_3=\bar{A}(\bar{A}\bar{B}u_0+\bar{B}u_1)+\bar{B}u_2
]

따라서:

[
x_3=\bar{A}^2\bar{B}u_0+\bar{A}\bar{B}u_1+\bar{B}u_2
]

출력은:

[
y_3=Cx_3
]

이므로:

[
y_3=C\bar{A}^2\bar{B}u_0+C\bar{A}\bar{B}u_1+C\bar{B}u_2
]


7. 여기서 중요한 사실: (y)는 결국 과거 입력들의 합이다

방금 식을 다시 보자.

[
y_3=C\bar{A}^2\bar{B}u_0+C\bar{A}\bar{B}u_1+C\bar{B}u_2
]

즉 (y_3)는 다음 입력들로 만들어진다.

u_0, u_1, u_2

처음에는 분명히 상태 (x)를 통해 계산했는데, (x)를 계속 펼쳐 쓰니까 결국 (y)가 과거 입력 (u)들의 조합으로 표현된다.

그래서 “(y)는 (u)의 함수다”라고 말할 수 있다.

더 정확히 말하면:

x는 과거 u들이 누적된 결과이고,
y는 그 x를 C로 읽은 것이므로,
y도 결국 과거 u들의 함수다.

8. Convolution kernel은 여기서 나온다

위 식에서 입력 앞에 붙은 계수를 보자.

[
C\bar{B}
]

[
C\bar{A}\bar{B}
]

[
C\bar{A}^2\bar{B}
]

이 값들을 하나의 리스트로 모으면:

[
K=(C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \dots)
]

이 (K)가 바로 convolution kernel이다.

그러면:

[
y_3=K_0u_2+K_1u_1+K_2u_0
]

처럼 쓸 수 있다.

일반적으로는:

[
y_t=K_0u_t+K_1u_{t-1}+K_2u_{t-2}+\cdots
]

이 된다.

이것이 convolution이다.

즉:

현재 출력 =
현재 입력 × 어떤 계수
+ 바로 이전 입력 × 어떤 계수
+ 더 이전 입력 × 어떤 계수
+ ...

논문에서도 discrete-time SSM의 convolution kernel을

[
K=(CB, CAB, \dots, CA^{L-1}B)
]

형태로 쓴다. 여기서는 notation을 단순화해서 discrete화된 (A,B)를 다시 (A,B)처럼 표기한다.


9. 숫자로 보는 아주 작은 예시

아주 단순하게 scalar 예시를 보자.

[
\bar{A}=0.5
]

[
\bar{B}=1
]

[
C=2
]

입력은:

u_0 = 10
u_1 = 20
u_2 = 30

초기 상태:

[
x_0=0
]


상태 업데이트로 계산

먼저:

[
x_1=0.5x_0+1u_0=10
]

[
y_1=2x_1=20
]

다음:

[
x_2=0.5x_1+u_1=0.5(10)+20=25
]

[
y_2=2x_2=50
]

다음:

[
x_3=0.5x_2+u_2=0.5(25)+30=42.5
]

[
y_3=2x_3=85
]


convolution으로 계산

kernel은:

[
K_0=C\bar{B}=2
]

[
K_1=C\bar{A}\bar{B}=2\cdot0.5\cdot1=1
]

[
K_2=C\bar{A}^2\bar{B}=2\cdot0.25\cdot1=0.5
]

따라서:

K = [2, 1, 0.5, ...]

그러면:

[
y_3=K_0u_2+K_1u_1+K_2u_0
]

[
=2\cdot30+1\cdot20+0.5\cdot10
]

[
=60+20+5=85
]

상태 업데이트로 계산한 결과와 같다.

즉 SSM은 두 가지 방식으로 볼 수 있다.

1. 상태를 하나씩 업데이트하는 recurrent 관점
2. 전체 입력에 convolution kernel을 적용하는 convolution 관점

둘은 같은 모델을 다르게 본 것이다.


10. 그럼 (K(t)=Ce^{tA}B)는 어디서 나오는가?

이제 연속시간으로 돌아가 보자.

연속시간 SSM은:

[
x'(t)=Ax(t)+Bu(t)
]

[
y(t)=Cx(t)
]

였다.

입력이 없으면:

[
x'(t)=Ax(t)
]

이 식의 해는:

[
x(t)=e^{tA}x(0)
]

이다.

여기서 (e^{tA})는 matrix exponential이다. 처음에는 이렇게 생각하면 된다.

e^{tA} = 상태를 t시간 동안 흘려보내는 연산자

즉 어떤 상태가 있을 때, (A)라는 규칙에 따라 (t)시간 뒤 어떻게 변하는지를 나타낸다.


입력이 한 번 들어온다고 생각하기

어떤 시점에 입력 (u)가 들어오면, 먼저 (B)를 통해 상태 공간에 들어간다.

u → B → 내부 상태에 주입

그다음 시간이 (t)만큼 지나면 이 상태는 (A)에 의해 변한다.

B u → e^{tA}B u

마지막으로 (C)가 이 상태를 출력으로 읽는다.

C e^{tA} B u

따라서 입력 하나가 들어온 뒤 (t)시간 후 출력에 주는 영향력은:

[
Ce^{tA}B
]

이다.

이걸 kernel로 정의한다.

[
K(t)=Ce^{tA}B
]

즉 이 식은 이렇게 읽으면 된다.

B:
입력을 상태에 넣는다.

e^{tA}:
그 상태가 t시간 동안 어떻게 변하는지 계산한다.

C:
변한 상태를 출력으로 읽는다.

그래서:

[
K(t)=Ce^{tA}B
]

는 “입력이 들어온 후 (t)시간 뒤 출력에 얼마나 영향을 주는가”를 나타낸다. 논문도 continuous SSM을 convolution 형태로 쓸 수 있고, kernel을 (K(t)=Ce^{tA}B)로 정의한다.


11. 왜 convolution 관점이 중요한가?

상태 업데이트 방식은 직관적이다.

x_0 → x_1 → x_2 → x_3 → ...

하지만 단점이 있다.

x_2를 계산하려면 x_1이 필요하고
x_3을 계산하려면 x_2가 필요하고
x_4를 계산하려면 x_3이 필요하다

즉 병렬화가 어렵다.

반면 convolution 관점에서는:

1. kernel K를 만든다.
2. 입력 u와 K를 convolution한다.

이렇게 전체 시퀀스를 한 번에 처리할 수 있다.

GPU에서는 convolution을 빠르게 병렬 계산할 수 있다.
그래서 S4 계열 모델은 학습할 때 convolution 관점을 매우 중요하게 사용한다.


12. 이제 S4D로 넘어가기

기존 S4도 SSM이다.

즉 기본적으로:

[
x'(t)=Ax(t)+Bu(t)
]

[
y(t)=Cx(t)
]

를 사용하고, 이를 convolution 형태로 계산한다.

그런데 기존 S4의 문제는 (A)가 복잡하다는 것이다.

S4는 긴 시퀀스를 잘 처리하기 위해 HiPPO라는 특별한 행렬을 사용했다. 이 행렬은 long-range dependency를 잘 처리하게 해주지만, 계산과 구현이 어렵다. 논문은 S4가 DPLR, 즉 diagonal plus low-rank 구조를 사용했고, 이 때문에 복잡한 선형대수 알고리즘이 필요했다고 설명한다.

대략적으로 말하면 기존 S4의 (A)는 이런 느낌이다.

[
A = \text{diagonal part} + \text{low-rank correction}
]

이 구조는 강력하지만 구현이 어렵다.


13. S4D의 핵심 아이디어

S4D는 아주 과감한 질문을 던진다.

low-rank correction을 없애고, (A)를 diagonal로만 두면 안 될까?

즉:

S4:
A = diagonal + low-rank

S4D:
A = diagonal only

diagonal 행렬은 이렇게 생겼다.

[
A=
\begin{bmatrix}
a_1 & 0 & 0 \
0 & a_2 & 0 \
0 & 0 & a_3
\end{bmatrix}
]

이렇게 되면 각 state dimension이 서로 섞이지 않는다.

x_1은 x_1끼리만 업데이트
x_2는 x_2끼리만 업데이트
x_3은 x_3끼리만 업데이트

즉 S4D는 여러 개의 독립적인 작은 기억 장치를 모아놓은 것처럼 볼 수 있다.

논문 Figure 1에서도 S4D를 여러 개의 독립적인 1차원 SSM들의 모음처럼 설명한다. 또한 diagonal 구조 덕분에 convolution kernel을 매우 단순하게 계산할 수 있다고 강조한다.


14. Diagonal이면 왜 계산이 쉬워질까?

일반적인 행렬 (A)는 (A^2, A^3, e^{tA}) 같은 계산이 어렵다.

하지만 (A)가 diagonal이면 아주 쉽다.

예를 들어:

[
A=
\begin{bmatrix}
a_1 & 0 \
0 & a_2
\end{bmatrix}
]

이면:

[
A^2=
\begin{bmatrix}
a_1^2 & 0 \
0 & a_2^2
\end{bmatrix}
]

[
A^3=
\begin{bmatrix}
a_1^3 & 0 \
0 & a_2^3
\end{bmatrix}
]

즉 각 diagonal 원소만 거듭제곱하면 된다.

그래서 discrete kernel:

[
K_\ell=C A^\ell B
]

도 diagonal (A)에서는 이렇게 풀린다.

[
K_\ell=\sum_{n=0}^{N-1} C_n A_n^\ell B_n
]

논문은 diagonal SSM의 kernel 계산이 Vandermonde matrix-vector multiplication으로 단순하게 표현된다고 설명한다.


15. S4D의 kernel 계산 예시

기억 장치가 2개 있다고 하자.

[
A_1=0.8
]

[
A_2=0.3
]

[
B_1=1,\quad B_2=1
]

[
C_1=2,\quad C_2=1
]

그러면:

[
K_0=C_1A_1^0B_1+C_2A_2^0B_2
]

[
=2\cdot1\cdot1+1\cdot1\cdot1=3
]

다음:

[
K_1=C_1A_1^1B_1+C_2A_2^1B_2
]

[
=2\cdot0.8+1\cdot0.3=1.9
]

다음:

[
K_2=C_1A_1^2B_1+C_2A_2^2B_2
]

[
=2\cdot0.64+1\cdot0.09=1.37
]

따라서:

K = [3, 1.9, 1.37, ...]

이 (K)를 입력 시퀀스와 convolution하면 출력이 나온다.


16. S4D에서 (A)는 무엇을 의미하나?

S4D에서 diagonal (A)의 각 원소 (A_n)은 하나의 기억 필터를 만든다.

복소수 형태로 보면 보통:

[
A_n = \text{real part} + i \cdot \text{imaginary part}
]

처럼 쓴다.


Real part: 기억이 얼마나 빨리 사라지는가

real part가 음수이면 기억이 시간이 지나면서 줄어든다.

많이 음수 → 빨리 잊음
0에 가까움 → 천천히 잊음

S4D에서는 보통 real part를 안정적으로 음수로 둔다.


Imaginary part: 어떤 주파수로 흔들리는가

imaginary part는 진동 주파수와 관련 있다.

작은 imaginary part → 느린 패턴
큰 imaginary part → 빠른 패턴

즉 S4D는 여러 (A_n)을 사용해서 다양한 패턴을 본다.

어떤 필터는 짧은 패턴을 보고
어떤 필터는 긴 패턴을 보고
어떤 필터는 주기적인 패턴을 본다

논문은 diagonal SSM에서 (A_n)의 real part는 decay rate를, imaginary part는 oscillating frequency를 조절한다고 설명한다.


17. S4D-Lin 예시

논문에서 제안하는 간단한 initialization 중 하나가 S4D-Lin이다.

[
A_n=-\frac12+i\pi n
]

이 식을 해석하면:

-\frac12:
모든 기억 장치가 적당히 감소하도록 만든다.

iπn:
각 기억 장치가 서로 다른 주파수로 진동하게 만든다.

예를 들어:

A_0 = -0.5 + i·0
A_1 = -0.5 + i·π
A_2 = -0.5 + i·2π
A_3 = -0.5 + i·3π
...

이렇게 하면 여러 주파수의 필터를 준비해두는 효과가 있다.

논문은 S4D-Lin을 damped Fourier basis처럼 볼 수 있다고 설명한다.


18. 중요한 점: 아무 diagonal (A)나 쓰면 안 된다

여기서 아주 중요한 포인트가 있다.

S4D가 diagonal이라고 해서 아무 (A)나 써도 되는 것은 아니다.

논문의 핵심 주장 중 하나는:

diagonal SSM이 잘 되려면 (A)의 initialization이 매우 중요하다.

라는 것이다.

랜덤한 diagonal matrix를 쓰면 성능이 좋지 않다.
반대로 HiPPO 구조에서 유도된 diagonal initialization이나, S4D-Inv, S4D-Lin 같은 잘 설계된 initialization을 쓰면 좋은 성능을 낸다.

논문은 DSS가 사용한 diagonal approximation이 왜 잘 되는지도 이론적으로 설명한다. 특히 HiPPO-LegS matrix의 low-rank 부분을 제거한 diagonal approximation이 state size가 커질수록 원래 S4와 비슷한 dynamics를 만든다는 점을 보인다.

비유하면 이렇다.

S4D의 A는 악기 줄 조율과 비슷하다.

diagonal 구조는 악기 자체를 단순하게 만든 것이고,
initialization은 줄을 올바른 음에 맞추는 것이다.

악기가 단순해도 조율이 엉망이면 좋은 소리가 나지 않는다.

19. S4D의 전체 흐름 정리

이제 S4D의 방법론을 처음부터 끝까지 정리해보자.

Step 1. 입력 sequence를 받는다

u = [u_0, u_1, u_2, ..., u_L]

Step 2. SSM 파라미터를 준비한다

A: 기억이 어떻게 변하는지
B: 입력을 기억에 어떻게 넣는지
C: 기억을 출력으로 어떻게 읽는지

S4D에서는 (A)를 diagonal로 둔다.


Step 3. 연속시간 파라미터를 discrete하게 바꾼다

연속시간 식:

[
x'(t)=Ax(t)+Bu(t)
]

을 실제 sequence에 맞게:

[
x_{k+1}=\bar{A}x_k+\bar{B}u_k
]

로 바꾼다.


Step 4. kernel을 만든다

diagonal SSM에서는 kernel이 간단하다.

[
K_\ell=\sum_n C_n A_n^\ell B_n
]

이것이 S4D의 핵심 계산이다.


Step 5. 입력과 convolution한다

[
y=K*u
]

즉:

현재 출력은 현재 입력과 과거 입력들의 가중합

이다.


20. S4D를 한 문장으로 말하면

S4D는:

복잡한 S4의 state matrix를 diagonal로 단순화하고, 잘 설계된 initialization을 사용해 긴 시퀀스를 효율적으로 처리하는 SSM 모델

이다.

조금 더 직관적으로 말하면:

S4D는 여러 개의 독립적인 기억 필터를 만들고,
각 필터가 과거 입력의 영향을 얼마나 오래, 어떤 주파수로 기억할지 정한 뒤,
그 결과를 convolution kernel로 만들어 입력 sequence에 적용하는 모델이다.

21. S4D를 비유로 이해하기

S4D를 “여러 개의 메아리 필터”라고 생각해보자.

동굴에서 소리를 냈다.

"안녕!"

그러면 여러 종류의 메아리가 돌아온다.

짧게 남는 메아리
오래 남는 메아리
빠르게 흔들리는 메아리
느리게 흔들리는 메아리

S4D의 각 diagonal state는 이런 메아리 하나에 해당한다.

A_n:
메아리가 얼마나 빨리 사라지고, 어떤 식으로 흔들리는지

B_n:
입력 소리가 그 메아리 필터에 얼마나 들어가는지

C_n:
그 메아리를 최종 출력에 얼마나 섞을지

K:
전체 메아리 패턴

y:
최종적으로 들리는 출력

이렇게 보면 S4D는 복잡한 attention 없이도 긴 과거 정보를 기억할 수 있다.


22. 기존 S4와 S4D의 차이

구분S4S4D
(A) 구조Diagonal + Low-rankDiagonal only
계산복잡함단순함
kernel 계산Cauchy 기반 복잡한 알고리즘Vandermonde / 거듭제곱 기반
구현 난이도높음낮음
성능매우 강함거의 비슷하게 강함
핵심 조건HiPPO 구조좋은 diagonal initialization

논문은 S4D가 S4보다 단순하면서도 Long Range Arena, 이미지, 오디오, medical time-series 등에서 강력한 성능을 낸다고 보고한다.


23. 최종 요약

처음부터 다시 요약하면 다음과 같다.

  1. SSM은 내부 기억 (x(t))를 가진다.

  2. (x'(t)=Ax(t)+Bu(t))는 기억의 변화율을 나타낸다.

  3. 연속시간 식을 discrete하게 바꾸면 (x_{k+1}=\bar{A}x_k+\bar{B}u_k)가 된다.

  4. 이 식을 계속 펼치면 (y)는 과거 입력 (u)들의 가중합으로 표현된다.

  5. 그 가중합의 계수들이 convolution kernel (K)다.

  6. 연속시간에서는 그 kernel이 (K(t)=Ce^{tA}B)로 나온다.

  7. 기존 S4는 (A)가 복잡해서 구현이 어렵다.

  8. S4D는 (A)를 diagonal로 만들어 계산을 매우 단순화한다.

  9. 단, (A)를 아무렇게나 초기화하면 안 되고, 주파수와 decay 구조를 잘 잡아야 한다.

  10. 그래서 S4D는 단순하지만 강력한 long-sequence 모델이 된다.

한 줄로 끝내면:

S4D는 “상태를 업데이트하는 모델”을 “긴 convolution filter를 만드는 모델”로 바라보고, 그 filter를 diagonal (A) 덕분에 아주 쉽게 계산하는 방법이다.

맞아. SSM의 기본 파라미터는 (A, B)만이 아니야.
최소한 기본 SSM에는 보통 (A, B, C) 가 있어.

기본 식은:

[
x'(t)=Ax(t)+Bu(t)
]

[
y(t)=Cx(t)
]

여기서 파라미터는:

파라미터역할
(A)내부 상태 (x)가 시간에 따라 어떻게 변하는지
(B)입력 (u)를 내부 상태 (x)에 어떻게 넣는지
(C)내부 상태 (x)를 출력 (y)로 어떻게 읽는지

즉, (A, B)는 상태를 업데이트하는 데 필요한 파라미터고,
(C)는 출력을 만들기 위한 파라미터야.


예를 들어 SSM을 “메모장”으로 보면:

[
x'(t)=Ax(t)+Bu(t)
]

는 메모장을 업데이트하는 규칙이야.

A: 기존 메모가 스스로 어떻게 변하는가
B: 새 입력을 메모장에 얼마나 적는가

그런데 메모장에 적힌 내용을 최종 답으로 읽어야 하잖아?

그게:

[
y(t)=Cx(t)
]

이고, 여기서 (C)가 필요해.

C: 메모장 내용을 어떻게 읽어서 출력으로 만들 것인가

기본 SSM의 파라미터

가장 기본 형태에서는:

[
A,\ B,\ C
]

가 파라미터야.

조금 더 일반적인 SSM에서는 (D)도 들어간다.

[
x'(t)=Ax(t)+Bu(t)
]

[
y(t)=Cx(t)+Du(t)
]

여기서 (D)는 입력을 바로 출력으로 보내는 skip connection 같은 역할을 해.

D: 현재 입력을 내부 상태를 거치지 않고 출력에 바로 반영하는 값

그래서 더 일반적으로는:

[
A,\ B,\ C,\ D
]

를 SSM 파라미터라고 볼 수 있어.


S4D에서는 뭐가 파라미터인가?

S4D에서는 기본적으로 다음이 중요해.

[
A,\ B,\ C,\ \Delta
]

그리고 구현에 따라 (D)도 있다.

파라미터S4D에서의 의미
(A)diagonal state matrix. 기억의 decay와 주파수 결정
(B)입력을 각 state에 넣는 weight
(C)state들을 섞어서 출력으로 읽는 weight
(\Delta)continuous-time SSM을 discrete sequence로 바꾸는 time step
(D)입력을 출력에 바로 더하는 skip 계수

논문에서도 SSM의 파라미터를 (A \in \mathbb{C}^{N \times N}), (B \in \mathbb{C}^{N \times 1}), (C \in \mathbb{C}^{1 \times N})로 설명한다.


제일 중요한 정리

[
x'(t)=Ax(t)+Bu(t)
]

만 보면 (A,B)만 있는 것처럼 보이지만, 이건 상태가 어떻게 변하는지만 말하는 식이야.

출력을 만들려면 반드시:

[
y(t)=Cx(t)
]

가 필요해.

그래서 SSM의 핵심 파라미터는:

[
\boxed{A,\ B,\ C}
]

이고, 실제 딥러닝 SSM/S4D layer에서는 보통:

[
\boxed{A,\ B,\ C,\ \Delta,\ D}
]

까지 생각하면 돼. 

Reference

댓글