Meta-003, Meta-Learning with Memory-Augmented Neural Networks (2016-JMLR)
0. Abstract
- 기존의 GD 학습은 많은 데이터를 필요로 광범위한 학습을 진행하는 식이다.
- 새로운 데이터가 들어오면, 모델은 비효율적으로 다시 파라미터를 재학습해야한다. (catastrophic inference 없이)
- Neural Turing Machines (NTMs)은 빠르게 encode하여 새로운 정보를 검색해서 conventional model의 단점을 줄이는 경향이 있다.
- 이 논문에서는 memory-agumented neural network을 제안하여 빠르게 새로운 데이터를 동화시키고 몇 개의 샘플만을 이용하여 정확한 예측을 하는 것을 보여준다.
- 또한 기존의 방법들과 같이 외부의 메모리에 access 한 방법을 소개한다.
- 여기까지만 봤을 때는, continuous learning에서 말하는 것과 같은 문제점을 제시한 것 같음.
- 즉 데이터가 조금 있는 상황 + 이미 학습한 모델에서 새로운 데이터가 들어온 상황 두 가지를 고려했을 때의 문제점을 해결하고자 하는 것 같다.
1. Introduction
- 현재의 딥러닝 성공은 (2017년) 많은 광범위한 데이터로 학습하고 성능을 측정하는 식이다.
- 반면에 많은 문제의 관심은 적은 데이터로부터 빠른 inference을 원한다.
- One-shot learning의 한계는 1개의 관찰 데이터로부터 이러한 행동을 해야한다는 점이다.
- 한 두개의 context에 등장하는 정보만으로 기계가 학습하는 것은 어렵다.
- 특히 딥러닝은 더욱더 쉽지 않은다.
- 몇 개의 샘플 one-by-one으로 주어지면 gradient-based로 파라미터를 완전 재학습을(랜덤 initialization) 해야한다.
- 이렇게 하면 poor learning이 될 경향이 높고 catastrophic inference가 될 것이다.
- 이러한 위험 관점에서는 non-parametric 방법이 더 나을 수가 있다.
- 이전 연구들에서는 meta-learning은 sparse data에서 빠르게 학습하는 방법에 대한 전략을 의미했었다.
- Meta-learning은 일반적으로 2개 levels에서 agent 학습 시나리오을 따른다.
- 각각 level은 time scales가 다르다.
- Task안에서 Rapid learning은 발생하는데, 예를 들면 특정 dataset을 분류를 학습하는 것이다.
- 이 학습은 도메인 task의 변화에 따라서 점진적으로 학습의 가이드가 발생한다.
- 주어진 2가지 구성에서, meta-learning은 종종 "learning to learn"으로 묘사된다.
- Memory capacity을 가지는 Neural network는 meta-learning을 할 수 있다는 것이 증명되어 왔다.
- 이러한 네트워크들은 bias을 weight updates로 변경시킬 뿐 아니라, memory stores안의 cache representations을 빠르게 학습함으로써 출력을 조정한다.
- 예를 들어, meta-learning으로 학습된 LSTMs은 적은 데이터로 구성된 이전에 모르던 2차 함수에 대해 빠르게 학습할 수 있다.
- Memory capacity을 가지는 뉴럴 네트워크느 meta-learning에 적절하다.
- 그러나 이러한 memory을 사용한 RNN 구조의 전략은 새로운 task가 원하는 방대한 새로운 정보의 양을 빠르게 인코딩하는데 적절하지 않다.
- 쉽게 생각해서 RNN을 사용하면 memory의 정보를 가져갈 수 있지만, RNN은 학습과 inference도 빠르지 않기 때문에 적은 정보에 적합한 네트워크가 아니란 것 같은데??
- 이러한 문제를 해결하기 위해 다음의 조건이 필요하다.
- 메모리에 저장된 정보가 stable하고 element-wise addressable 해야 한다.
- Stable: 필요할 때, 안정적으로 접근하기 위함
- Element-wise addressable: 정보와 관련된 일부분을 선택적으로 접근하기 위함
- 파라미터의 수가 메모리 사이즈에 묶여있으면 안된다.(의존적이면 안된다?)
- 이 두 가지 특성은 LSTM에서 자연스럽게 일어나지 않는다.
- 그러나 NTM이나 Memory networks은 이러한 기준을 만족시킨다.
- 따라서 이 논문에서는, meta-learning 문제를 재정의 하고 memory-augmented neural network (MANN) 관점에서 setup을 한다.
- 여기서 MANN은 외부 메모리를 참조한다는 것이고 LSTM처럼 내부의 메모리를 이용한다는 것이 아니다.
- MANNs은 meta-learning tasks에서 short / long term memory 요구에 잘 수행한다.
- 본적 없는 Omniglot classes 분류 테스크에서 적은 샘플로만으로 사람과 같은 정확도를 보여준다.
- 기존 NTM과 달리 여기서의 memory access module은 content에 접근할 수 있음을 말한다. (memory location의 추가가 없음)
- 논문의 접근법은 두 가지 best을 결합한 것이다.
- GD 학습을 통해 useful representation을 천천히 학습하는 것
- 외부 memory module을 통한 한 번의 presentation 후, 전에 본 적 없는 정보를 빠르게 습득
- 이러한 조합은 meta-learning에 강력하고 많은 범위의 문제에 확장이 가능하다.
- NTM의 논문을 읽은 적이 없어 블로그를 찾아서 봤지만 어렵...
2. Meta-Learning Task Methodology
- 보통 some dataset D에 대해서 cost L을 최소화하도록 parameters 을 학습한다.
- 하지만 meta-learning에서는 parameters을 dataset p(D)의 분포에서 expected cost을 줄이는 방식으로 한다. (뭐가 다른거지? 기댓값이라는 개념?)
- 이 논문의 셋업 task에서 다음과 같이 데이터세트가 있다고 하자.
- Classification에서는 이미지 에 대한 가 class label이다.
- Regression에서는 real 값 에 대한 hidden function의 출력이 이다.
- 즉 두 가지 task 모두 입력이 이고 가 target이며 temporally offset 방식으로 가 주어진다.
- 이 뜻은 처럼 입력 sequence가 주어진다는 것이다.
- (a)와 같이 입력이 구성이 된다는 것인데 현재 time-step에서의 입력에 이전의 time-step에서의 correct label이 같이 들어오게 되는 식이다. (마치 language model에서 teacher-forcing 느낌?)
- 모델 그림을 쉽게 LSTM이라 생각하면 될듯
- 여기서 중요한 점은 labels are shuffled from dataset-to-dataset,
- 예) class 5개에 대해 데이터를 10개씩 모은다.
- 50 개의 데이터를 섞은다는 것 같음
- 따라서 네트워크가 제대로 된 학습을 못하지만..? 다음 step에서 제대로 된 정답을 줌으로써 데이터 샘플의 메모리를 다음의 step까지 적절히 가져가야하는 점을 배운다.
- 궁극적으로 시스템이 초점을 두는 모델링은 다음과 같다.
- 각 time step에서의
- 보통 classification할 때, 에 해당하는 출력이 가 되도록 학습하지만,
- 여기서는 t-1에 해당하는 데이터 집합에 대해 기억(메모리)을 가져간다는 것.
- 이것이 가능한 이유는 correct label을 다음 step에서 제시하기 때문
- 즉 figure 1의 (b)에서 t+1 step에서 에 해당하는 값이 2로 external 메모리로 가져간다.
- 그리고 t+n step에서 때와 같은 class 입력이 들어왔을 때, t step에서 메모리에 write한 것(external memory)을 이용하면 된다 같은 개념?
- 아직도 이해가 안가는 부분이 많으나 뒷 부분을 읽어보자...
3. Memory-Augmented Model
3.1 Neural Turing Machines
- NTM에는 controller가 있는데 feed-forward network 혹은 LSTM와 같은 것을 말한다.
- 많은 수의 read와 write heads을 이용한 이는 external memory와 상호작용을 한다.
- Memory encoding과 NTM external memory 모듈에서 retrieval은 빠르다.
- vector representation은 대체되거나 매 time-step에서 잠재적으로 제거한다.
- 이 기능은 NTM을 meta-learning과 low-shot prediction에 완벽하게 만든다.
- 느린 weights 업데이트를 통한 long-term storage
- external memory 모듈을 통한 short-term storage
- 만약 NTM이 메모리에 씌여질 representation에 대한 general strategy을 배우고 이를 prediction에 쓸 수 있다면, only 한 번 본 것에 대해서 정확한 예측을 수행할 수 있다.
- verification 방식 처럼을 의미하는 건가?
- Controllers은 앞에서 말했듯이 LSTM 또는 feed-forward networks이다.
- 이것은 read와 write을 사용하여 external memory module과 상호작용을 한다.
- read와 write는 각각 memory에서 representations을 찾거나 memory에 쓰는 것을 말한다.
- 주어진 입력이 이고 이에 대해서 controller가 key 을 생성했다고 하자.
- key는 memory matrix 에 저장되거나 particular memory, 예로 i번째 열을 찾는데 사용된다.
- 메모리를 검색할 때, 는 cosine similarity 측정 방법을 통해 사용된다.
- 즉 어떤 입력에 대해서 생성된 key(feature?)가 기존의 memory에 비슷한게 있으면 그것을 출력으로 하고 아니면 새롭게 memory에 쓴다는 것 같음.
- 정확히는 다음과 같은 수식을 따른다.
- 즉 식 (3)에서 key와 memory들과의 가중치를 구하고 의 가중치 합을 memory 으로 사용한다.
- controller부터 사용된 메모리는 softmax output layer와 같이 classifier의 입력으로 사용된다.
- 또한 추가적으로 next controller state을 위한 추가적인 입력으로도 사용된다.
- 위 말은 LSTM에서의 hidden state vector처럼 출력을 위해서도 사용되고 다음 cell에도 입력으로 들어간다는 개념인 듯
- 그런데 그러면 언제 memory에 write하는 것에 대해서는 3.2 절에서
3.2 Least Recently Used Access
- NTM의 이전 instantiations은 content와 location 모두에 의해 결정되었다.
- Location-based은 memory을 가로질러 long-distance을 넘나들 뿐 아니라 tape을 따라서 달리는 것과 비슷하다.
- sequence-based 예측 task에서 유리한 방법이다.
- 이 방법은 conjunctive coding of information independent of sequence에 적합하지 않다.
- 쉽겟 seq2seq에서 쓰는 attention 기법은 sequential한 입력이기 때문에 location-based 방법을 써도 되는데, 여기서 입력은 sequential 한 개념이 아니기 때문에 LRUA을 사용한다는 것 같은데?
- 따라서 LRUA을 제시한다.
- LRUA는 content-based memory writer이다.
- least used memory location (가장 적게 사용한 메모리 위치) 혹은 most recently used memory location (가장 최근 사용한 메모리 위치) 둘 중 한개를 write하는 것
- This module emphasizes accurate encoding of relevant (i.e., recent) information, and pure content-based retrieval.
- 새로운 정보가 오면
- rarely-used location에 (즉 거의 안쓰던 위치) 인코딩된 정보를 보존 (업데이트를 하는 것)
- last used location에 (가장 최근 사용하던 위치) 메모리를 새로운 것으로 업데이트를 한다. (이 방법은 가장 연관성이 높을 위치이기 때문에)
- 이 두가지 옵션은 previous read weights (이전의 read weights) 와 weights scaled according to usage weights (사용한 weight에 따라 scaled된)에 따라 interpolation이 된다.
- 이러한 weight의 사용은 매 time-step에서 previous usage weights을 업데이트하고 current read와 write weight을 다음과 같이 추가한다.
- u는 usage, r은 read, w는 write, 은 decay parameter이다.
- 즉 이전 step t-1에서 usage한 weight에다가 read와 write할 weight을 더해주는 것이 현재 step t에서의 usage weight이다.
- read weight는 3.1 절에서 식 (3)에서 설명을 했고 write weight는 밑에서 설명 (+이를 위한 weight_lu도 설명)
- 개념은 사용한 weight이기 때문에 이전 step에서 decay을 취해주고 read, write weight을 사용했으니까 더해주는 느낌
- least-used weights,
- 즉 에서 n번째로 작은 항보다 가 크냐 작냐에 따라 0또는 1의 값을 가진다.
- 그렇다면 n이라는 숫자는? memory을 읽은 숫자로 세팅한다. 무슨 의미지?
- 즉 보다 작다는 것은 잘 사용하지 않았다는 것(least used)이고 이에 해당하는 값은 1로 두어서 write weight을 구할 때, 새롭게 쓰겠다! - 1) rarely-used location 방법
- write weight 을 얻기 위해서 learnable sigmoid gate parameter을 사용하여 previous read weight와 leas-used weight을 convex combination한다.
- 여기서 sigma가 sigmoid function이고 alpha는 scalar gate parameter로 interpolate을 결정하는 것이다.
- 을 사용하는 것은 식 6에서도 설명했지만 이는 LRUA에서 새로운 정보가 왔을 때 write하는 위치를 결정하는 1)번 방법을 의미하고
- 을 더해주는 것은 이전 step t-1에서 사용한 memory을 더해주는 것으로 2)번 방법 last used location을 의미한다.
- write weight 을 이용하여 최종 write는 다음과 같다.
- Thus, memories can be written into the zeroed memory slot or the previously used slot; if it is the latter, then the least used memories simply get erased.
- 따라서 메모리는 1) zeroed memory slot에 write가 될 수도 2) 이전에 사용된 slot에 write일 수도 있는데 2)일 경우는 least used memory가 사라지는 형태이다.
- 정리하자면)
- 새로운 정보가 오면, 이를 이용하여 read도 하고 write도 한다.
- 어디다 read를 하고 write을 하느냐? M이라는 matrix을 이용한다.
- Matrix의 각 열들이 정보를 담고 있는 것인데..
- read: 입력에 대해 key k을 만들어서 M 열들과 cosine softmax을 취해서 비슷한 만큼 열의 weighted sum으로 메모리를 읽는다.
- write: write weight을 만들어서 M 각 열에 key k을 weighted 한 것을 더해준다.
4. Experimental Results
4.1 Data
- 두 가지 데이터 세트를 사용하였다.
- Omniglot for classification
- Gaussian process(GP)에서 샘플링한 데이터 for regression
- Omniglot
- 1600 class
- class당 few examples
- MNIST 급으로 유명
- Overfitting을 줄이기 위하여 randomly translating and rotating
- 60도, 180도, 270도 회전을 통하여 새로운 class을 만들었다.
- 학습 데이터로는 1200개의 기존 클래스를 (+augmentation) 이용하였다.
- 나머지 425 클래스는 (+augmentation) 테스트에 사용하였다.
- 즉 학습은 1200+ / 테스트는 425+ 클래스를 사용했단 거겠지?
- 이미지는 20x20으로 다운스케일하여 사용하였다.(계산 복잡성 때문)
4.2 Omniglot Classification
- 그림 2
- 학습 episode는 100,000개로 5개의 랜덤 선택된 class와 랜덤 선택된 labels이다.
- 그리고 네트워크에 한 번도 본 적 없는 test episode가 주어진다.
- 테스트 episode으로는 더 이상 학습이 일어나지 않고 한 번도 본 적 없는 class labels을 예측한다.
- 높은 정확도를 보여준다. 82.8%
- 5번째 instance는 94.9%까지 달성하고 10번째 instance는 98.1%까지 달성
- instance가 머지??
- 그리고 사람이 성능을 평가하는 방법에 대해서 설명되어 있으나 생략
논문 실험 부분을 읽다 보니 이해가 안되는 부분이 너무 많았다.
이를 다 이해하는 것은 비효율적이라 생각되어 시간이 오래 걸려 뒷 부분은 실험 결과만 첨부하고 생략해야겠다....
이를 다 이해하는 것은 비효율적이라 생각되어 시간이 오래 걸려 뒷 부분은 실험 결과만 첨부하고 생략해야겠다....
MANN은 external memory을 이용한 model-based의 가장 기초고 이 논문에서 파생된 스트림을 나중에 포스팅을 하면서 좀 더 자세히 이해해보자.
4.x Figures
Reference
댓글
댓글 쓰기