NL-051, Distilling Task-Specific Knowledge from BERT into Simple Neural Networks (2019-Arxiv)
■ Comment
Reference
- 이 논문도 NL-50논문과 같이 preprint 이기는 하지만, distill을 시도한 첫 번째 논문으로 여겨진다.
- 아마 어딘가에 제출했다가 reject을 당하였겠지만 오픈리뷰가 아니여서 확인해 볼 수는 없다.
- 보통 distilling은 기존의 모델만큼 깊은 구조는 아니지만 같은 Transformer의 구조를 띄는 모델에 하는 것이 일반적으로 들었었는데 이 논문은 새로운 구조에 한 것이 인상 깊었다.
- 논문에서 언급하는 정도의 성능 정도로만 distilling 된다면, 서비스에서는 시간이 오래걸리는 구조가 아닌 LSTM을 써도 괜찮을 것 같다는 생각이 들었다.
- NL-50보다 먼저 나온 논문으로 BERT 경량화에서 처음으로 나온 논문으로 보여진다.
*Distilling
1.Task
data로 Fine-tuned
BERT을 student에
distilling을
한다.
2.이
때는 data labels와
teacher의
predicted
prob을
동시에 이용한다.
0. Abstract
- NLP에서 BERT, ELMo, GPT와 같이 deep language representation 계열의 연구가 트렌드이다.
- 이것들은 shallower neural network가 NLP에서 이제는 큰 의미가 없다고 이끌고 있다.
- 하지만, 기초적인 가벼운 네트워크도 external training, additional input features을 이용하면 여전히 경쟁령있다고 한다. (아키텍쳐 구조의 변함없어도)
- 여기서 하고자하는 것은 BERT을 single-layer BiLSTM 구조에 distill knowledge하는 것이다.
- paraphrasing, NLI, sentiment classification의 여러 개의 데이터세트에서 ELMo와 비교할만한 성능을 보여주고 100배의 적은 parameters와 15배의 빠른 inference time을 기록한다.
1. Introduction
- NLP에서 기존의 구조를 이용하여 성능이 증가하고 있다. (복잡도와 깊이도 깊어지고)
- 이러한 과정 중, first-generation, 즉 초기의 뉴럴 네트워크는 무시되고 있다.
- 표면적으로 pre-trained deep word representation은 다양한 task에서 SoTA을 보여주고 있다.
- 최근의 연구로는 BERT와 GPT-2가 많은 데이터를 이용하여 큰 transformer 모델을 가지며 SoTA들을 보여준다.
- 이러한 큰 네트워크들은 그러나 실제로 사용할 때는 문제가 있다.
- 만흥ㄴ parameters로 인하여 자원적인 제한이 있고 모바일 시스템과 같은 곳에 적용하기가 어렵다.
- 즉 inference 시간과 효율적인 면에서 실제 시스템에서는 사용하기 어렵다.
- 이 논문에서는 효과적인 task-specific knowledge을 하기 위한 접근법을 소개한다.
- 특별히 얇은 구조인 BiLSTM에 지식을 전달하게 되고 두 가지 motivation을 가지고 한다.
- 1) we question whether a simple architecture actually lacks representation power for text modeling
- 2) we wish to study effective approaches to transfer knowledge from BERT to a BiLSTM.
- 즉, knowledge distillation 접근법을 레버리지를 함으로써 큰 모델을 teacher로 간주하고 작은 모델이 teacher을 모방하는 student로 간주한다.
- knowledge transfer을 효과적으로 사용하기 위하여 크지만 unlabled 데이터세트를 필요로한다.
- 즉 teacher model이 이러한 unannotated samples에 대해 labels을 예측한다.
- CV 분야에서는 이러한 unlabled image을 rotation, additive noise와 같은 것으로 얻기가 쉽지만 NLP의 specific task에서는 얻기가 어렵다.
- 전통적인 NLP의 agumentation은 task-epcific을 하게 되는 것이며 다른 NLP task에 적용하기 어렵다.
- 따라서 여기서 새로운 rule-based textual data agumentation 접근법을 소개한다.
- 실제 이러한 samples이 fluent하지는 않지만, 실험적으로 성능을 좋게함을 보여준다.
- 예전부터 knowledge transfer 분야에서 궁금한 것이, 정상 데이터만을 student가 teacher을 모방하는 방법이 아니라 noise 데이터라도 모방하게 하면 결국에는 모델의 분포가 비슷하게 따라가는 것이 아닌가 하는 의문점이다.
- 즉 이러한 데이터가 당연히 효과가 있을 것이라 생각만 했는데 이런 것을 입증한 논문이 따로 있을까?
- 여기서는 3개의 task에 대해 실험을 한다.
- Sentence classification
- Sentence matching
- knowledge distillation 과정은 중요하고 기존의 simpler network만을 따로 하는 것보다 좋은 성능을 보여준다.
- 이것이 처음으로 distilling을 시도하였으며 얇은 BiLSTM 구조가 ELMo과 비교할만한 수준을 보여준다.
- 심지어 100배 파라미터수가 적고 15배 빠르다고 한다.
2. Related Work
- 간단하게만 살펴보면
- 언어 처리에서의 이 논문과 관련있는 이전의 연구들
- CNN 계열 연구들
- Kalchbrenner et al., 2014
- Kim, 2014
- RNN 계열 연구들
- Mikolov et al., 2010, 2011
- Graves, 2013
- Recursive NN
- Socher et al., 2010, 2011
- Sentence classification
- Zhang et al., 2015
- Conneau et al., 2016
- Sentence matching
- Wan et al., 2016
- He et al., 2016
- Pre-trained Model
- Peters 2018, ELMo
- Devlin 2018, BERT
- Model Compression 계열 연구들
- 고전 방법
- LeCun 1990
- Han 2015
- 하지만 이러한 방법은 weight sparsity와 같은 일반적이지 않은 상황인, 많은 optimized 계산이 필요한 상황이 아닌 것에 대한 연구이다.
- Pruning entire filters
- Li 2016, Liu 2017
- device-centric metric을 타게팅하는 것이다.
- Quantizing neural networks
- Wu 2018
- Courbariaux 2016
- 네트워크를 binarize하는 것으로 binary weight 혹은 binary activation을 하게 된다.
- Knowledge distillation
- Ba and Caruana 2014
- Hinton 2015
- 이 논문에서 쓰는 방법으로 teacher라 불리는 큰 네트워크를 student라 불리는 작은 네트워크로 knowledge transfer하는 개념이다.
- NMT(번역) task에서 Kim and Rush 2016 연구가 있었고 LM에서 Yu 2018 연구가 있었다.
3. Our Approach
- 여기서는 크게 두 가지 요소가 있다.
- 1) logits-regression objective
- 2) transfer 데이터세트 구성 (데이터 증강으로 효과적인 지식 전달을 위해)
3.1 Model Architecture
- Teacher network은 pre-trained BERT을 fine-tuned시킨 것을 의미한다.
- BERT을 finetune 시킬 때는 크게 특이한 점은 없다.
- single 문장에 대해서는
- BERT에 문장을 태워서 나온 feature에 fc을 통과시켜 학습을 시킨다.
- 논문에서는 언급을 하지는 않지만 표기를 봤을 때는 feature은 [CLS] token을 쓰는 것 같다.
- sentence-pair에 대해서는
- 두 문장의 BERT feature을 concat 시킨 후, single 문장처리하는 것과 똑같이 진행한다.
- finetune할 때는 BERT 파라미터도 학습을 하는 방향으로 진행했다.
- 위 그림들은 student network BiLSTM의 모델 구조이다.
- Fig 1은 하나의 문장을 처리하는 task에 대한 그림이고 Fig 2는 두 문장(pair) task에 대해 처리하는 구조이다.
- 위 그림에서 설명이 없는 점으로는 concatenate-compare에 대한 것인데 이에 대한 정의는 다음과 같다.
- 이 외에 다른 제약이나 추가적인 trick은 없다고 한다. (attention 혹은 layer normalization과 같은)
3.2 Distillation Objective
- Ba and Caruana (2014)에서 one-hot predicted label외에도 teacher가 예측한 확률을 모방하는 것도 중요하다고 한다.
- binary 문장 분류에서는 몇몇 문장들은 강한 감정을 가지는 중립적인 성향을 띄는 문장들도 있다.
- 만약 우리가 teacher가 예측한 one-hot label만으로 student을 학습한다면, 예측 uncertainty에 대한가치있는 정보를 잃어버리게 된다.
- discrete 확률은 다음과 같이 정의된다.
- logits에 대해 학습하는 것은 student 모델이 teacher에서 targets 예측에 중요하다고 생각하는 관계에 대해 배우기 때문에 학습을 더 쉽게 만든다.
- 이에대한 distillation objective은 MSE로 다음과 같다.
- 다른 측정법으로는 soft targets의 cross entropy이다.
- 그러나 실험적으로 MSE가 좀 더 낳은 성능을 보여준다.
- 학습할 때 distilling objective는 다음과 같이 one-hot label t에 대한 전통적인 cross-entropy을 따른다.
- t는 one-hot target으로 ground-truth label을 의미한다.
- 만약 unlabeled dataset을 사용할 때는, teacher가 predicted label을 사용하는데 이면 =1이고 그 외에는 =0이다.
- 즉 teacher의 logit이 가장 큰 class로 gt를 대체하여 학습한다는 것
- 정리하면)
- 앞의 term은 GT label로 cross-entropy (~NLL)로 학습을 의미하는데 unlabeld data라 GT label이 없으면 teacher가 판단한 것으로 GT을 대체하는 것
- 두 번째 term은 teacher의 logit과 student의 logit이 같도록 MSE loss을 이용한다는 것이다.
3.3 Data Augmentation for Distillation
- 작은 데이터세트에서는 teacher model이 fully하게 지식을 전달하는게 충분하지 않다.(Ba and Caruana, 2014)
- 따라서 큰 데이터세트가 필요하고 unlabeled 데이터세트인 경우는 pseudo-labels을 teacher로부터 제공을 한다.
- NLP는 비전 영역과 다르게 단순히 rotation과 같은 것으로 데이터 증강시키게되면 문장이 유창하지 않고 깨끗하지도 않다.
- 여기서는 휴리스틱한 증강법을 제안한다.
- Masking
- POS-guided word replacement
- token을 확률을 이용해서 같은 POS tag을 가지는 다른 단어로 치환한다.
- 기존 문장의 분포를 유지하기 위해 새로운 단어는 unigram word distribution re-normalized by the part-of-speech (POS) tag에서 샘플링 한다.
- What do pigs eat? → How do pigs eat?
- n-gram sampling
- 예제에서 확률로 n-gram을 랜덤으로 샘플한다.
- n은 1부터 5사이의 숫자고 이것은 word dropping을 하는 것으로 좀 더 공격적인 masking과 같은 것이다.
- 즉 하나의 문구를 없앤다는 개념과 같은거니까
- 학습 예시가 와 같을 때, 반복적으로 을 한다.
- 좀 더 자세히 위의 것을 설명하면)
- 만약 이면 wi을 masking한다.
- 만약 이면 POS-guided word replacement을 적용한다.
- 즉 masking과 POS-guided는 상호 배타적인 것으로 룰이 하나만 적용이 된다.
- 단어들을 iterating을 한 후, 확률로 n-gram sampling으로 전체 합성 예제를 만든다.
- 이렇게 만들어진 최정 예제가 agumented가 되고 unlabeled dataset으로 append 된다.
- 한 예제에 대해 위의 절차를 만큼만 반복하여 생성한다.
- 만약 sentence pair 데이터같은 경우는 (second 문장을 고정한 후) first 문장을 증강시키거나 그 반대로 하게 된다.
4. Experimental Setup
- BERT(Large)을 teacher 모델로 사용하고 Adam optimizer을 사용하고 lr은 {2,3,4,5}ㅌ10^-5으로 학습하면서 validation 에서 가장 좋은 모델을 고른다.
- fine-tuning할 대는 data agumentation을 피한다.
- distilled BiLSTM을 학습은 3.2 section에서 alpha가 0을 선택하였다.
- 0을 선택했다면 왜 alpha을 논문에서 언급한 것이지?
- alpha에 따른 실험결과도 있으면 좋을 것 같다고 생각함.
4.1 Datasets
- GLUE 데이터세트를 이용하였으며 다음과 같은 task가 있다.
- SST-2
- MNLI
- QQP
4.2 Hyperparameters
- BiLSTM: 150혹은 300 hidden units
- ReLU: 200 or 400 units
- word2vec: 300 차원
- AdaDelta로 학습
- ==0.1, =0.2, =0.25
4.3 Baseline Models
- BERT
- OpenAI GPT
- GLUE ELMo baselines
- 위의 3가지 모델 다 pre-trained 모델로 유명한 것들이다.
5. Results and Discussion
- 위의 결과를 보면, non-distilled BiLSTM은 BERT의 ELMo baseline 보다 성능이 좋지 않다.
- training data을 증강시켜서 Distillation 접근법을 적용하면, BiLSTM 성능이 증가하고 ELMo보다 성능이 좋아진다.
- Distilled model은 이전의 ELMo(row 4, 5)에 비해 경쟁력이 있으며 shallow BiLSTM들은 이전에 생각했던 것보다 더욱 좋은 representation을 보여준다.
- 기존의 deep transformer models보다 성능이 4~7 point정도 떨어지지만, 훨씬 적은 paramters와 좋은 효율성을 보여주는데 다음의 섹션에서 설명해준다.
5.2 Inference Efficiency
- 파이토치를 쓰고 V100 한장을 써서 실험하였다.
- batch는 512, 67350 문장의 SST 학습 데이터를 사용
- Table 2를 보면, BiLSTM(SOFT)가 훨씬적은 parameters와 inference 시간이 빠르다는 것을 알 수가 있다.
6. Conclusion and Future Work
- BiLSTM 기본 모델에 BERT을 knowledge distilling을 탐구하였다.
- 파라미터의 수는 적고 인퍼런스 시간은 적게들면서 ELMo에 비교할만한 성능을 달성하였다.
- 이 결과로 얇은 BiLSTMs들도 natural language tasks들에 이전보다 더욱 인상적인 결과를 보여줄 수 있음을 확인한다.
- Future work로는 더욱 간단한 구조인 CNN와 같은 구조, SVM, logistic regression 등에 적용을 해보는 것이다.
- 다른 방향으로는 조금 더 복잡한 구조인 word interaction과 attention과 같은 tricks을 사용한 모델에 적용해보는 것이다.
댓글
댓글 쓰기