Meta-003, Meta-Learning with Memory-Augmented Neural Networks (2016-JMLR)

0. Abstract

  • 기존의 GD 학습은 많은 데이터를 필요로 광범위한 학습을 진행하는 식이다.
  • 새로운 데이터가 들어오면, 모델은 비효율적으로 다시 파라미터를 재학습해야한다. (catastrophic inference 없이)
  • Neural Turing Machines (NTMs)은 빠르게 encode하여 새로운 정보를 검색해서 conventional model의 단점을 줄이는 경향이 있다.
  • 이 논문에서는 memory-agumented neural network을 제안하여 빠르게 새로운 데이터를 동화시키고 몇 개의 샘플만을 이용하여 정확한 예측을 하는 것을 보여준다.
  • 또한 기존의 방법들과 같이 외부의 메모리에 access 한 방법을 소개한다.
  • 여기까지만 봤을 때는, continuous learning에서 말하는 것과 같은 문제점을 제시한 것 같음.
  • 즉 데이터가 조금 있는 상황 + 이미 학습한 모델에서 새로운 데이터가 들어온 상황 두 가지를 고려했을 때의 문제점을 해결하고자 하는 것 같다.

1. Introduction

  • 현재의 딥러닝 성공은 (2017년) 많은 광범위한 데이터로 학습하고 성능을 측정하는 식이다.
  • 반면에 많은 문제의 관심은 적은 데이터로부터 빠른 inference을 원한다.
  • One-shot learning의 한계는 1개의 관찰 데이터로부터 이러한 행동을 해야한다는 점이다.
  • 한 두개의 context에 등장하는 정보만으로 기계가 학습하는 것은 어렵다.
  • 특히 딥러닝은 더욱더 쉽지 않은다.
    • 몇 개의 샘플 one-by-one으로 주어지면 gradient-based로 파라미터를 완전 재학습을(랜덤 initialization) 해야한다.
    • 이렇게 하면 poor learning이 될 경향이 높고 catastrophic inference가 될 것이다.
    • 이러한 위험 관점에서는 non-parametric 방법이 더 나을 수가 있다.
  • 이전 연구들에서는 meta-learning은 sparse data에서 빠르게 학습하는 방법에 대한 전략을 의미했었다.
    • Meta-learning은 일반적으로 2개 levels에서 agent 학습 시나리오을 따른다.
    • 각각 level은 time scales가 다르다.
    • Task안에서 Rapid learning은 발생하는데, 예를 들면 특정 dataset을 분류를 학습하는 것이다.
      • 이 학습은 도메인 task의 변화에 따라서 점진적으로 학습의 가이드가 발생한다. 
    • 주어진 2가지 구성에서, meta-learning은 종종 "learning to learn"으로 묘사된다.
  • Memory capacity을 가지는 Neural network는 meta-learning을 할 수 있다는 것이 증명되어 왔다.
    • 이러한 네트워크들은 bias을 weight updates로 변경시킬 뿐 아니라, memory stores안의 cache representations을 빠르게 학습함으로써 출력을 조정한다.
    • 예를 들어, meta-learning으로 학습된 LSTMs은 적은 데이터로 구성된 이전에 모르던 2차 함수에 대해 빠르게 학습할 수 있다.
  • Memory capacity을 가지는 뉴럴 네트워크느 meta-learning에 적절하다.
    • 그러나 이러한 memory을 사용한 RNN 구조의 전략은 새로운 task가 원하는 방대한 새로운 정보의 양을 빠르게 인코딩하는데 적절하지 않다.
    • 쉽게 생각해서 RNN을 사용하면 memory의 정보를 가져갈 수 있지만, RNN은 학습과 inference도 빠르지 않기 때문에 적은 정보에 적합한 네트워크가 아니란 것 같은데??
  • 이러한 문제를 해결하기 위해 다음의 조건이 필요하다.
    1. 메모리에 저장된 정보가 stable하고 element-wise addressable 해야 한다.
      • Stable: 필요할 때, 안정적으로 접근하기 위함
      • Element-wise addressable: 정보와 관련된 일부분을 선택적으로 접근하기 위함
    2. 파라미터의 수가 메모리 사이즈에 묶여있으면 안된다.(의존적이면 안된다?)
  • 이 두 가지 특성은 LSTM에서 자연스럽게 일어나지 않는다.
    • 그러나 NTM이나 Memory networks은 이러한 기준을 만족시킨다.
    • 따라서 이 논문에서는, meta-learning 문제를 재정의 하고 memory-augmented neural network (MANN) 관점에서 setup을 한다.
    • 여기서 MANN은 외부 메모리를 참조한다는 것이고 LSTM처럼 내부의 메모리를 이용한다는 것이 아니다.
  • MANNs은 meta-learning tasks에서 short / long term memory 요구에 잘 수행한다.
    • 본적 없는 Omniglot classes 분류 테스크에서 적은 샘플로만으로 사람과 같은 정확도를 보여준다.
    • 기존 NTM과 달리 여기서의 memory access module은 content에 접근할 수 있음을 말한다. (memory location의 추가가 없음)
  • 논문의 접근법은 두 가지 best을 결합한 것이다.
    1. GD 학습을 통해 useful representation을 천천히 학습하는 것
    2. 외부 memory module을 통한 한 번의 presentation 후, 전에 본 적 없는 정보를 빠르게 습득
  • 이러한 조합은 meta-learning에 강력하고 많은 범위의 문제에 확장이 가능하다.
  • NTM의 논문을 읽은 적이 없어 블로그를 찾아서 봤지만 어렵... 

2. Meta-Learning Task Methodology

  • 보통 some dataset D에 대해서 cost L을 최소화하도록 parameters 을 학습한다.
  • 하지만 meta-learning에서는 parameters을 dataset p(D)의 분포에서 expected cost을 줄이는 방식으로 한다. (뭐가 다른거지? 기댓값이라는 개념?)
    • 일반적인 학습 작업과 매우 유사 해 보이지만 하나의 데이터 집합 이 하나의 데이터 샘플 로 간주된다. 
  • 이 논문의 셋업 task에서 다음과 같이 데이터세트가 있다고 하자.
    • Classification에서는 이미지 에 대한 가 class label이다. 
    • Regression에서는 real 값 에 대한 hidden function의 출력이 이다.
    • 즉 두 가지 task 모두 입력이 이고 가 target이며 temporally offset 방식으로 가 주어진다.
    • 이 뜻은  처럼 입력 sequence가 주어진다는 것이다.
    • (a)와 같이 입력이 구성이 된다는 것인데 현재 time-step에서의 입력에 이전의 time-step에서의 correct label이 같이 들어오게 되는 식이다. (마치 language model에서 teacher-forcing 느낌?)
    • 모델 그림을 쉽게 LSTM이라 생각하면 될듯
    • 여기서 중요한 점은 labels are shuffled from dataset-to-dataset, 
      • 예) class 5개에 대해 데이터를 10개씩 모은다.
      • 50 개의 데이터를 섞은다는 것 같음
    • 따라서 네트워크가 제대로 된 학습을 못하지만..? 다음 step에서 제대로 된 정답을 줌으로써 데이터 샘플의 메모리를 다음의 step까지 적절히 가져가야하는 점을 배운다.
    • 궁극적으로 시스템이 초점을 두는 모델링은 다음과 같다.
    • 각 time step에서의 
      • 보통 classification할 때, 에 해당하는 출력이 가 되도록 학습하지만,
      • 여기서는 t-1에 해당하는 데이터 집합에 대해 기억(메모리)을 가져간다는 것.
      • 이것이 가능한 이유는 correct label을 다음 step에서 제시하기 때문
      • 즉 figure 1의 (b)에서 t+1 step에서 에 해당하는 값이 2로 external 메모리로 가져간다.
      • 그리고 t+n step에서 때와 같은 class 입력이 들어왔을 때, t step에서 메모리에 write한 것(external memory)을 이용하면 된다 같은 개념?
    • 아직도 이해가 안가는 부분이 많으나 뒷 부분을 읽어보자...

3. Memory-Augmented Model

3.1 Neural Turing Machines

  • NTM에는 controller가 있는데 feed-forward network 혹은 LSTM와 같은 것을 말한다.
    • 많은 수의 read와 write heads을 이용한 이는 external memory와 상호작용을 한다.
  • Memory encoding과 NTM external memory 모듈에서 retrieval은 빠르다.
    • vector representation은 대체되거나 매 time-step에서 잠재적으로 제거한다.
  • 이 기능은 NTM을 meta-learning과 low-shot prediction에 완벽하게 만든다.
    • 느린 weights 업데이트를 통한 long-term storage
    • external memory 모듈을 통한 short-term storage
  • 만약 NTM이 메모리에 씌여질 representation에 대한 general strategy을 배우고 이를 prediction에 쓸 수 있다면, only 한 번 본 것에 대해서 정확한 예측을 수행할 수 있다.
    • verification 방식 처럼을 의미하는 건가?
  • Controllers은 앞에서 말했듯이 LSTM 또는 feed-forward networks이다.
    • 이것은 read와 write을 사용하여 external memory module과 상호작용을 한다.
    • read와 write는 각각 memory에서 representations을 찾거나 memory에 쓰는 것을 말한다.
  • 주어진 입력이 이고 이에 대해서 controller가 key 을 생성했다고 하자.
    • key는 memory matrix 에 저장되거나 particular memory, 예로 i번째 열을 찾는데 사용된다.
    • 메모리를 검색할 때, 는 cosine similarity 측정 방법을 통해 사용된다.
    • 즉 어떤 입력에 대해서 생성된 key(feature?)가 기존의 memory에 비슷한게 있으면 그것을 출력으로 하고 아니면 새롭게 memory에 쓴다는 것 같음.
    • 정확히는 다음과 같은 수식을 따른다.
    • 즉 식 (3)에서 key와 memory들과의 가중치를 구하고 의 가중치 합을 memory 으로 사용한다.
    • controller부터 사용된 메모리는 softmax output layer와 같이 classifier의 입력으로 사용된다.
    • 또한 추가적으로 next controller state을 위한 추가적인 입력으로도 사용된다.
    • 위 말은 LSTM에서의 hidden state vector처럼 출력을 위해서도 사용되고 다음 cell에도 입력으로 들어간다는 개념인 듯
    • 그런데 그러면 언제 memory에 write하는 것에 대해서는 3.2 절에서

3.2 Least Recently Used Access

  • NTM의 이전 instantiations은 content와 location 모두에 의해 결정되었다.
  • Location-based은 memory을 가로질러 long-distance을 넘나들 뿐 아니라 tape을 따라서 달리는 것과 비슷하다.
    • sequence-based 예측 task에서 유리한 방법이다.
    • 이 방법은 conjunctive coding of information independent of sequence에 적합하지 않다.
    • 쉽겟 seq2seq에서 쓰는 attention 기법은 sequential한 입력이기 때문에 location-based 방법을 써도 되는데, 여기서 입력은 sequential 한 개념이 아니기 때문에 LRUA을 사용한다는 것 같은데?
    • 따라서 LRUA을 제시한다.
  • LRUA는 content-based memory writer이다.
    • least used memory location (가장 적게 사용한 메모리 위치) 혹은 most recently used memory location (가장 최근 사용한 메모리 위치) 둘 중 한개를 write하는 것
  • This module emphasizes accurate encoding of relevant (i.e., recent) information, and pure content-based retrieval. 
  • 새로운 정보가 오면
    1. rarely-used location에 (즉 거의 안쓰던 위치) 인코딩된 정보를 보존 (업데이트를 하는 것)
    2. last used location에 (가장 최근 사용하던 위치) 메모리를 새로운 것으로 업데이트를 한다. (이 방법은 가장 연관성이 높을 위치이기 때문에)
    • 이 두가지 옵션은 previous read weights (이전의 read weights) 와 weights scaled according to usage weights (사용한 weight에 따라 scaled된)에 따라 interpolation이 된다.
  • 이러한 weight의 사용은 매  time-step에서 previous usage weights을 업데이트하고 current read와 write weight을 다음과 같이 추가한다.
    • u는 usage, r은 read, w는 write, 은 decay parameter이다.
    • 즉 이전 step t-1에서 usage한 weight에다가 read와 write할 weight을 더해주는 것이 현재 step t에서의 usage weight이다.
    • read weight는 3.1 절에서 식 (3)에서 설명을 했고 write weight는 밑에서 설명 (+이를 위한 weight_lu도 설명)
    • 개념은 사용한 weight이기 때문에 이전 step에서 decay을 취해주고 read, write weight을 사용했으니까 더해주는 느낌
  • least-used weights, 
    • 이는 time-step에서 을 계산하기 위해 사용이 된다.
    • 은 vector v에서 n번째로 작은 항을 말한다.
    • 즉 에서 n번째로 작은 항보다 가 크냐 작냐에 따라 0또는 1의 값을 가진다.
    • 그렇다면 n이라는 숫자는? memory을 읽은 숫자로 세팅한다. 무슨 의미지?
    • 즉 보다 작다는 것은 잘 사용하지 않았다는 것(least used)이고 이에 해당하는 값은 1로 두어서 write weight을 구할 때, 새롭게 쓰겠다! - 1) rarely-used location 방법
  • write weight 을 얻기 위해서 learnable sigmoid gate parameter을 사용하여 previous read weight와 leas-used weight을 convex combination한다.
    • 여기서 sigma가 sigmoid function이고 alpha는 scalar gate parameter로 interpolate을 결정하는 것이다.
    • 을 사용하는 것은 식 6에서도 설명했지만 이는 LRUA에서 새로운 정보가 왔을 때 write하는 위치를 결정하는 1)번 방법을 의미하고
    • 을 더해주는 것은 이전 step t-1에서 사용한 memory을 더해주는 것으로 2)번 방법 last used location을 의미한다.
  • write weight 을 이용하여 최종 write는 다음과 같다.
    • Thus, memories can be written into the zeroed memory slot or the previously used slot; if it is the latter, then the least used memories simply get erased.
    • 따라서 메모리는 1) zeroed memory slot에 write가 될 수도 2) 이전에 사용된 slot에 write일 수도 있는데 2)일 경우는 least used memory가 사라지는 형태이다.
  • 정리하자면)
    • 새로운 정보가 오면, 이를 이용하여 read도 하고 write도 한다.
    • 어디다 read를 하고 write을 하느냐? M이라는 matrix을 이용한다.
    • Matrix의 각 열들이 정보를 담고 있는 것인데..
    • read: 입력에 대해 key k을 만들어서 M 열들과 cosine softmax을 취해서 비슷한 만큼 열의 weighted sum으로 메모리를 읽는다.
    • write: write weight을 만들어서 M 각 열에 key k을 weighted 한 것을 더해준다.
      • 그렇다면 은 얼만큼 write 할까를 의미해야한다.
      • 그러면 적게 사용한 위치에 (least-used) 메모리를 새로쓰기도 해야하고 ()
      • 이미 있는 메모리를 업데이트를 하는 역할을 해야한다. ()
      • 이미 있는 메모리는 바로 직전 step에서 사용한 메모리의 weight 와 연관이 있을 것이라 생각하는 것이다.
        • 사실 개인적으론 이 부분이 의심쩍다.
        • 입력이 sequential하게 연관되어 있는 것이 아닌데 왜 꼭 연관성이 높은 거지?
        • 물론 메모리를 가져간다는 점에서 이렇게 해석을 할 수 있을 것 같긴 하다.

4. Experimental Results

4.1 Data

  • 두 가지 데이터 세트를 사용하였다.
    • Omniglot for classification
    • Gaussian process(GP)에서 샘플링한 데이터 for regression
  • Omniglot
    • 1600 class
    • class당 few examples
    • MNIST 급으로 유명
    • Overfitting을 줄이기 위하여 randomly translating and rotating
    • 60도, 180도, 270도 회전을 통하여 새로운 class을 만들었다.
    • 학습 데이터로는 1200개의 기존 클래스를 (+augmentation) 이용하였다.
    • 나머지 425 클래스는 (+augmentation) 테스트에 사용하였다.
      • 즉 학습은 1200+ / 테스트는 425+ 클래스를 사용했단 거겠지?
    • 이미지는 20x20으로 다운스케일하여 사용하였다.(계산 복잡성 때문)

4.2 Omniglot Classification

  • 그림 2
  • 학습 episode는 100,000개로 5개의 랜덤 선택된 class와 랜덤 선택된 labels이다.
  • 그리고 네트워크에 한 번도 본 적 없는 test episode가 주어진다.
  • 테스트 episode으로는 더 이상 학습이 일어나지 않고 한 번도 본 적 없는 class labels을 예측한다.
    • 높은 정확도를 보여준다. 82.8%
    • 5번째 instance는 94.9%까지 달성하고 10번째 instance는 98.1%까지 달성
    • instance가 머지??
  • 그리고 사람이 성능을 평가하는 방법에 대해서 설명되어 있으나 생략
논문 실험 부분을 읽다 보니 이해가 안되는 부분이 너무 많았다.
이를 다 이해하는 것은 비효율적이라 생각되어 시간이 오래 걸려 뒷 부분은 실험 결과만 첨부하고 생략해야겠다....
MANN은 external memory을 이용한 model-based의 가장 기초고 이 논문에서 파생된 스트림을 나중에 포스팅을 하면서 좀 더 자세히 이해해보자.

4.x Figures

Reference

댓글