NL-028, SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient (2017-AAAI)
0. Abstract
- 아직 언어는 GAN으로 생성하는데 한계가 존재
- 대표적으로 2가지 이유가 있음.
- G에서 생성된 출력이 discriminative model에서 generative model까지의 gradient update가 되기 어렵게 만듬.
- 또한 D model의 loss는 문장이 완성되어서야 평가하기 때문에, 중간 중간 sequence가 잘 생겨나는지에 대한 control이 안됨.
- 자세한 건 뒤에서 더...
- SeqGAN은 이러한 것을 해결하려고 함.
- RL의 stochastic policy을 이용하여 data generator을 하여 gradient policy update 문제를 해결하려고 함.
- stochastic을 사용하기 위해 policy gradient인 듯(value based가 아닌)
- RL에서 reward는 D model에서 나오는 값을 이용. Monte Carlo search 이용 등..
1. Introduction
- 쉽게 생각해보는(초창기) 방법은 RNNLM이다. Maximize log predictive likelihood을 하는 것
- 이 때 previous observed tokens을 가지고 LM 방식으로 하면 된다.
- 하지만 이 방법은 Inference시 exposure bias 문제가 발생한다.
- 모델이 sequence를 반복해서 생성하는 문제
- training과 inference의 간극으로 인해 발생되는 문제(token을 생성할 때, previous token을 조건으로 가져가야 하므로)
- 따라서 scheduled sampling(SS) 방법이 2015년에 제시되었다. (background 참고)
- 하지만 SS는 inconsistent training strategy이고 근본적으로 이 문제를 해결하는 것은 아니다.
- 다른 해결 방법으로는 전체 생성된 sequence에 해당하는 loss function을 설정하는 것이다.
- 즉 전체 생성된 문장이 잘 생성 되었는지에 대한 loss를 추가적으로..
- 예를 들어, 번역 문제에서 BLEU는 sequence generation을 guide에 해주는데 적용해볼 수 있다.
- 그러나 sequence loss 설정 방법은 다른 많은 practical application에서 적용하기 어렵다.
- 예를 들어, poem generation/chatbot 에서는 task-specific loss를 바로 적용하는게 불가능 할지 모른다.
- 아마 생성된 문장에 대한 loss를 설정하는게 애매해서 그런듯..? 번역같은 경우는 정답이란게 있어서 괜찮은 것 같으나 chatbot은 꼭 그런것은 아니니까..
- GAN을 이용해서 이러한 문제를 줄일 수 있다.
- exposure bias 문제를 없애기 위해선 전체 문장에 대한 평가가 필요한데 이를 D model로 G로 생성된 문장이 잘 생성 되었는지를 구별한다는 것이다. 이를 이용한 sequence loss 설정 하겠다는 것인 듯!
- 하지만 여기서 GAN 모델은 대표적으로 두 가지 문제가 등장한다.
- GAN은 real-valued continuous data를 생성하나 token은 discrete하다는 점.
- 여기서 말하길 G로 생성된 데이터에 관한 D loss의 gradient가 G모델을 변화를 거의 못시킨다는 점이다.
- 즉 G가 제대로 생성못해도 이를 deterministic transform을 해야한다. 왜냐하면 discrete token으로 만들어 줘야하니까.
- 실제 G에서 생성되는 continuous value를 어느 token으로 매칭을 시켜야하니까 continuous value가 크게 바뀌나 보다. 즉 G는 continuous space고 token들은 discrete space다 보니 이 차이가 학습에 지장을 주는 듯??...
- 코드를 보니 vocab size에(이 논문에선 5000) 맞는 one-hot encoding을 output으로 출력함.
- "나는 집에 간다" 라는 문장을 LM으로 학습한다 했을 때, "나는"(=[0,0,1] vector)이 입력이 들어가서 "집에"(=[0,1,0])라는 token이 나오려고 하면 [0,0,1]과 [0,1,0]사이의 모든 값은 존재하지 않는다는게 문제라는데..단어가 가지고 있는 공간은 총 3가지 뿐이기 때문.
- 그런데 NLG가 아닌 일반적인 RNNLM을 이용한 task을 할 때 생기는 문제점은 아닌가?
- 어차피 LM의 학습을 통하여 적용하는 task의 output space는 continuous하므로 괜찮은 듯?
- A major reason lies in that the discrete outputs from the generative model make it difficult to pass the gradient update from the discriminative model to the generative model.
- 내 식대로 이해해 보자면 즉 GAN에서 G모델을 학습시, D가 진짜라고 믿게끔 학습하는 것인데, 만약 G에서 "I are good"란 문장이 나왔다고 하자.
- D입장에서 이 놈이 진짜라고 믿게끔 해줘야하니까 "I am good"로 문장이 생성되도록 D에서 G에게 어떠한 신호를 보내줘야 한다.
- 하지만 are -> am은 discrete하게 멀찍이 떨어져있는 vector이다. (discrete token space)
- 따라서 이것에 대한 정보를 전달하기 힘들다는 문제점인 것 같다.
- one-hot vector이 들어가서 one-hot vector (next token)이 예측되어야 하는 상황임. 그런데 output이 one-hot vector로 discrete token space이므로 실제 예측 function은 discrete한 것이 나와야 한다.
- 따라서 이론적으로는 우리가 원하는 model G는 미분이 불가능할 것이다.
- 하지만 우리는 그냥 continuous G 모델로 one-hot vector을 가지도록 학습을 할 수는 있으나 이러면 괴리감이 있기 때문에 제대로 학습이 안된다는 점 같다.
- (추가부분) 언어 생성 모델을 짜고 직접 학습하면서 이부분을 이해했는데...
- decoder(generator)을 통하여 vocab에 있는 token들의 확률들이 나오게 된다.
- 이 확률을 가지고 argmax을 취해서 token을 결정해주고 이를 다시 word2vec을 취한후 D(discriminator)를 통과시켜 이용하는게 일반적이다.
- 여기서 argmax란 discrete한 function이기 때문에 gradient가 안흐르게 된다.
- 따라서 argmax을 안취하고 output 자체가 discrete한 space을 도출해야하는데 이게 힘들다는 것이고 이 부분이 이미지의 GAN과 다르다는 것이다.
- soft-embedding을 이용하여 해결할 수도 있는데 soft-embedding이란 그냥 확률 값을 이용하여 word2vec을 통과시키는 것이다. (어떻게 보면 weighted word2vec 같은 것)
- 즉 soft-embedding을 이용하면 continuous하기 때문에 학습이 가능하다.
- 이 논문에서는 policy gradient처럼 D의 결과를 reward식으로 이용하여 학습하는 것을 설명한다.
- input으로 word2vec 혹은 더 나아가서 BERT embedding을 넣으려고 해도 비슷한 문제점이 있다는 것을 생각해볼 수 있음!
- 그래서 Professor forcing 논문이 이러한 문제점을 지적하는거 같은데..
- 어찌되었든 우리가 학습할 모델이 결국 discrete space이면, BP 미분식으로 학습하는 continuous 식으로는 결국 제대로 된 학습이 안되는 것!?
- GAN에서 D model은 오직 entire sequence, 즉 문장이 다 생성되고나야 score/loss을 매길 수 있다는 점이다. 즉 문장이 잘 생성되고 있는지에 대해 partial sequence에서 feedback을 줄만한 score/loss가 없다는 점 같다.
- 지금까지 생각으론, RNNLM이 Generator 모델로 각 token에 대한 loss등을 설정하며 문장 생성을 할 수 있다. 하지만 이 때 exposure bias가 발생할 것이고 이를 해결하기 위해서 문장 전체에 대한 loss를 주는 과정이 필요하고 GAN을 이용해보자!
- 전체적인 문장에 대한 loss는 Discriminator 모델을 만들어 학습을 시킬 것이다. 이 때 D loss는 전체 문장에 대한 것이고 실제, 문장이 잘 생성되는지에 대한 feedback을 줄 중간 partial sequence에 대한 어떠한 설정이 없다는 점 같다. (즉 이미지하고 조금 다른게, 이미지는 마지막에 잘 생성되었는지 보면 되는데, 언어는 생성이 time series식으로 하나씩 token이 생성되는데 이에 대한 고려가 일반적인 GAN에서는 안한다는 말인 듯)
- 이 논문에서는 bahdanau(actor-critic 방법)을 기반으로 한 방식을 따른다.
- RL Agent: 생성 모델
- state: generated token
- action: next token to be generated
- reward: Feedback given to guide G by D on evaluating the generated sequence
- 방법은 Monte Carlo (MC) search 으로 state-action value 예측을 한다.
2. Related Work
- 생략
3. Sequence Generative Adversarial Nets
- 이 부분부터는 강화학습의 REINFORCEMENT 알고리즘을 이해하고 있어야 논문이 이해가 된다.
- 강화학습이 사실 개인적으로 쉽지 않다고 생각하는 분야라 완벽히 마스터하기는 ....
- 하지만 강화학습이 메인이 아니라 이 논문처럼 방법론으로 사용하는 정도라면 데이비드 실바 강의(한글로는 팡요랩)를 이해하면 되는 정도라고 생각함!
- Generator model
는
-parameterized 이고 sequence
을 생성하게 된다.
- Timestep t에서 state
을 말하며 current provided tokens을 말한다.
- 즉 t-1까지의 생성된 token까지 모아둔 sequence을 state라고 생각
- action a는 next token
to select을 말한다.
- 즉 t time에 어떤 token을 선택할지를 말하는 것이다.
- Policy model
은 stochastic이고 state transition은 deterministic이다.
- 이 말은 1:t-1까지 tokens까지 생성되어 있고, state
에서
를 생성한다는 것은 next state
은
로부터 state transition 확률이 1이다. 다른 state s''와 같은 것으로 갈 state transition 확률은 0이 된다.
- 쉽게 생각해서 state를 표현하는 것은 유일하므로(token들의 집합이 다르면 다른 state이니까) state transition이 deterministic 할 수 밖에 없다.
-parameterized discriminateive model
은
가 향상되도록 guidance를 주는 역할을 한다.
- Figure 1을 보면, 오른쪽이 SeqGAN에 대한 흐름인데 state에서 action을 해가며 끝까지 간다.
- 다음 챕터에서 나오는지 모르겠지만 여기서 언제 문장이 끝나는 지에 대한 언급은 없다.
- 즉 MC search로 문장들을 생성한 후, D로부터 reward를 받게 되는데 이 reward가 의미하는 것은 문장이 제대로 생성 됐는지 안됐는지에 대한 확률이라 보면 된다.
- 이 reward를 이용하여 policy gradient 방법으로 학습하게 되는 방식이다.
3.1 SeqGAN via Policy Gradient
- 강화학습에서 policy gradient에서 objective function을 episodic environment 환경에서
을 생각해볼 수 있다.
- 이를 기반으로 여기서도 objective function을 설정할 수 있다. G가 생성 모델이고 G는 다음의 obj func을 maximize을 하도록 학습을 시키는 개념이다.
- Q는 reward의 축적 값인데 이 reward는 discriminator
에서 부터 발생한다.
- 즉 문장이 완성된 sequence를 D가 진짜 문장인지 아닌 지로 binary classification 하는데, 이 때 진짜일 확률 값이 reward라고 생각하면 된다.
- Q(action-value) 함수는 REINFORCEMENT 알고리즘을 이용하여서 학습하였다.
- 즉 다시 말하면
(2) 식이 Q값을 가지고 이는 문장이 완성된 상태에서 발생하는 reward값이다.
- 그렇다면 중간 과정에서의 reward는??
- 이 식대로라면 중간 과정의 reward는 0으로 처리가 된다.
- 즉 final step만 고려한다면, previous token의 적합성만 고려하게 되는데, 이 뿐 아니라 중간 중간 생성과정에서 future 결과 또한 고려하는 것을 넣어줘야 한다.
- 이렇게 Q값을 정하기 위해서 Monte Carlo search with roll-out policy를 적용한다.
- Variance을 줄이고 Q의 정확한 평가를 높이기 위하여 rollout 사용함
- 쉽게 생각하면 현 상태, 현 policy로 미래를 탐색한 다음, 이로부터 reward로 받겠다는 것이다.
- MC로 탐색하기에 시간, 비용등 때문에 roll-out policy(간단한)을 적용하는 것이다.
- rollout 참고!
- 따라서 다음의 식(3) 같이 N-time MC search를 표현할 수 있다.
- 이는 y1~yt가 주어지고, yt+1~yT까지를 roll-out policy
를 따라서 N번 sample 하겠다는 것이다.
는 generator을 결국 의미하는 것이고 simple version을 사용하는데 뒤에서 어떤 모델인지 나옴.
- 따라서 이제 수정된 Q value는 다음 식(4) 와 같다.
- 이렇게
을 사용한 reward의 장점은, generative model을 반복적으로 향상 시킬 수 있다는 것이다. 또한 G 모델이 제대로 된 문장을 생성하고 나서는 D모델을 re-train을 다음과 같이 하는 것이 좋다.(즉 GAN 방식을 적용할 수 있다)
- 최종적으로 obj function (1)은 다음과 같이 (6)식으로 표현할 수 있다.
- 맨 앞의 sigma 부분은 MC search를 담당
- expect value는 정의 + condition Y 설정
- 뒷 괄호 부분은 매 sample에 대한 obj function이라고 보면 됨.
- 이 (6)식은 밑의 (7)처럼 유도가 된다.
- 따라서 식 (8)로 업데이트를 하면 됨. alpha는 learning rate이다.
- 이러한 과정을 다합쳐서...전체적인 학습은 다음과 같다.
- 여기서 초기화는 MLE로 G를 학습하고 이 G에대한 D를 학습하는 식으로 pre-training을 한다.
- 이렇게 pre-training 하는 것이 효과적임을 실험적으로 알았다고 함.
- Dataset S: positive, Generate data: negeative sample로 사용했다고 하고, D를 학습할 때, balance를 고려하여 d-steps = #generative negative samples = #positive examples 가 되도록 설정하였다고 함. (d-step마다 negative sample을 생성하고 positive sample을 하나 sampling 하는 듯?)
- 여기서 generator policy를 roll-out policy에 대입을 하는데 이는 큰 의미는 없는 것 같음.
- 단지 학습 시 generator로 중간까지 생성된 문장을 고정하고 rollout으로 그 뒤의 문장을 생성 후 reward를 받는데 사용하는 식임
3.2 The Generative Model for Sequences
- Generative model로 LSTM을 사용함.(GRU도 사용가능 함)
- softmax output layer을 bias vector c와 weight matrix V를 이용하여 추가함.
3.3 The Discriminative Model for Sequences
- Text classification 모델로는 CNN / RCNN / RNN 등 여러 방법이 있지만 여기서는 Zhang and LeCUN 2015 방법을 썼다고 한다.
- CNN 구조를 가지며, appendix나 논문을 찾으면 자세히 설명이 나오는데 (Kim 방법이랑 똑같은거 같은데..) 아주 간단히만 적어보면
- token embedding 된 것을 concat을 시킨다.
- convolution을 태운다. filter size는 여러 개로
- 나온 값을 max over time pooling을 시킨다.
- highway network을 태운다.
- 마지막으로 FC을 통하여 cross entropy 방식으로 학습.
4. Synthetic Data Experiments
- 실험을 하기 위해서 synthetic data을 만들었다고 함.
- Oracle로 불리는 랜덤 초기화 LSTM을 true model로 사용했다고 함.
- 즉 진짜 real data로 실험한다는 것이 아니고 "랜덤 초기화 LSTM"을 oracle로 생각하고 이게 만드는 데이터가 진짜 data라고 생각하자.
- 그러면 이 데이터로 seqGAN을 학습하여서 seqGAN이 oracle을 잘 모방한다면 제대로 학습이 된다는 의미로 볼 수 있다.
4.1 Evaluation Metric
- 내가 이해한 바로는 학습된 generative model G가 문장을 생성한다.
- 생성: 예시 문장1) 나는 논문 보느라 목이 아프다 / 예시 문장2) 나는 책을 읽어서 배가 고프다.
- 이것을 oracle(=사람한테) "나는"(t=1)이란 token을 보여줘서 그 다음 token을 만들라고 한다. "논문"이 나올 확률을 NLL로 구한다.
- "나는 논문"을 보여주고(t=2) "보느라"가 나올 NLL을 구한다.
- "나는 논문 보느라"을 보여주고(t-3) "목이"가 나올 NLL을 구한다.
- "나는 논문 보느라 목이"을 보여주고 "아프다"가 나올 NLL을 구한다.
- 이런식으로 나온 NLL을 다 더한다.
- 상대적으로 여기서 예시 문장1은 NLL이 낮게 나올 것이고 예시 문장2는 NLL이 높게 나올 것이다.
- 이렇게 예시 문장들에 대한 expect value을 구한다.(즉 다 더해서 개수를 나누는 방식으로 할 듯)
- 여기서 이해를 위해 oracle을 사람으로 생각하고 real data를 가지고 seqGAN을 학습시켰다고 생각하는 것이랑 개념이 비슷한 것 같다.
- 실제로는 랜덤 초기화 LSTM이 oracle이라고 한다. 랜덤 LSTM으로 next token이 생성되는 것이 real data라고 가정한다는 것이다.
- 이 데이터를 가지고 seqGAN을 학습시키겠다는 것이다.
- 그러면 oracle이 real distribution이고 학습된 seqGAN이 oracle을 잘 모방하면 학습모델이 real distribution을 따른 다는 것이고 수식상 NLL_oracle이 작아진 다는 것이다.
4.2 Training Setting
- Normal distribution N(0,1)을 따르는 LSTM을 oracle로 real data을 만든다.
- 10,000 sequences of length 20으로 training set S을 만든다.
- Dataset S는 label을 1로 generated examples은 label 0으로 구성한다.
- 1. Random token generation, 2. MLE trained LSTM, 3. SS, 4. PG-BLEU와 비교를 하였다.
4.3 Results
5. Real-world Scenarios
- 그렇다면 진짜 데이터세트로 실험해보자.(poem composition, speech language generation, music generation)
5.1 Text Generation
- Chinese poems와 Barack obama political speeches 두 개에 대해 실험함.
- Generation texts와 human-created texts 사이의 BLEU score로 유사성을 구함.
- BLEU score은 원래 번역 품질을 위해 만들어진 방법이라고 함.
- 중국 시는 n-gram(2)인 BLEU-2로 평가를 하였는데 이는 중국 시가 한 개 혹은 두 개의 characters로 구성이 되어 있다고 함.
- 오바마 연설문은 비슷한 이유로 BLEU-3과 BLEU-4로 평가했다고 함.
- 중국 시 같은 경우, MLE와 SeqGAN, real data 각각 20개씩 총 60개의 시를 70명의 중국 시 전문가에 대해 평가를 받았다.
- 진짜 시 같으면 +1, 아니면 0점을 줘서 scoring 하였다.
- 오바마 연설문에 대해서는 왜 안했는지 아쉽.. 아마 지네들이 보기에 제대로 생성으로 보기에는 무리가 있어서 그랬을 듯.
5.2 Music Generation
- Nottingham datasets을 사용하였다고 함.
- BLEU와 MSE로 평가를 함. (위에 Table 4)
- 이걸로 생성된 음악이 어떨지 개인적으로 궁금..
- 사실 음악 생성은 단지 패턴을 기억해서 섞어둔 정도가 아닐까 싶은데
6. Conlcusion
- GAN으로 discrete token의 시퀀스를 생성한 첫 번째 시도이다.
- Oracle evaluation mechanism 사용
- MCTS와 value network을 사용해서 발전 가능.(알파고 논문 처럼 말하는 듯)
- 추가적으로 appendix에 pre-training의 유무에 따른 실험이 있는데 pre-training을 안하면 좀 처럼 성능이 좋아지지 않는다. 즉 MLE로 pre-training 하는 것이 중요한 팁인듯!
Reference
댓글
댓글 쓰기