NL-028, SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient (2017-AAAI)

0. Abstract

  • 아직 언어는 GAN으로 생성하는데 한계가 존재
  • 대표적으로 2가지 이유가 있음.
    1. G에서 생성된 출력이 discriminative model에서 generative model까지의 gradient update가 되기 어렵게 만듬.
    2. 또한 D model의 loss는 문장이 완성되어서야 평가하기 때문에, 중간 중간 sequence가 잘 생겨나는지에 대한 control이 안됨.
    • 자세한 건 뒤에서 더...
    • SeqGAN은 이러한 것을 해결하려고 함.
  • RL의 stochastic policy을 이용하여 data generator을 하여 gradient policy update 문제를 해결하려고 함.
    • stochastic을 사용하기 위해 policy gradient인 듯(value based가 아닌)
  • RL에서 reward는 D model에서 나오는 값을 이용. Monte Carlo search 이용 등..

1. Introduction

  • 쉽게 생각해보는(초창기) 방법은 RNNLM이다. Maximize log predictive likelihood을 하는 것
  • 이 때 previous observed tokens을 가지고 LM 방식으로 하면 된다.
  • 하지만 이 방법은 Inference시 exposure bias 문제가 발생한다.
    • 모델이 sequence를 반복해서 생성하는 문제
    • training과 inference의 간극으로 인해 발생되는 문제(token을 생성할 때, previous token을 조건으로 가져가야 하므로)
  • 따라서 scheduled sampling(SS) 방법이 2015년에 제시되었다. (background 참고)
  • 하지만 SS는 inconsistent training strategy이고 근본적으로 이 문제를 해결하는 것은 아니다.
  • 다른 해결 방법으로는 전체 생성된 sequence에 해당하는 loss function을 설정하는 것이다.
    • 즉 전체 생성된 문장이 잘 생성 되었는지에 대한 loss를 추가적으로..
    • 예를 들어, 번역 문제에서 BLEU는 sequence generation을 guide에 해주는데 적용해볼 수 있다.
  • 그러나 sequence loss 설정 방법은 다른 많은 practical application에서 적용하기 어렵다.
    • 예를 들어, poem generation/chatbot 에서는 task-specific loss를 바로 적용하는게 불가능 할지 모른다.
    • 아마 생성된 문장에 대한 loss를 설정하는게 애매해서 그런듯..? 번역같은 경우는 정답이란게 있어서 괜찮은 것 같으나 chatbot은 꼭 그런것은 아니니까..
  • GAN을 이용해서 이러한 문제를 줄일 수 있다.
    • exposure bias 문제를 없애기 위해선 전체 문장에 대한 평가가 필요한데 이를 D model로 G로 생성된 문장이 잘 생성 되었는지를 구별한다는 것이다. 이를 이용한 sequence loss 설정 하겠다는 것인 듯!
  • 하지만 여기서 GAN 모델은 대표적으로 두 가지 문제가 등장한다.
    1. GAN은 real-valued continuous data를 생성하나 token은 discrete하다는 점.
      • 여기서 말하길 G로 생성된 데이터에 관한 D loss의 gradient가 G모델을 변화를 거의 못시킨다는 점이다.
      • 즉 G가 제대로 생성못해도 이를 deterministic transform을 해야한다. 왜냐하면 discrete token으로 만들어 줘야하니까.
      • 실제 G에서 생성되는 continuous value를 어느 token으로 매칭을 시켜야하니까 continuous value가 크게 바뀌나 보다. 즉 G는 continuous space고 token들은 discrete space다 보니 이 차이가 학습에 지장을 주는 듯??...
      • 코드를 보니 vocab size에(이 논문에선 5000) 맞는 one-hot encoding을 output으로 출력함.
      • "나는 집에 간다" 라는 문장을 LM으로 학습한다 했을 때, "나는"(=[0,0,1] vector)이 입력이 들어가서 "집에"(=[0,1,0])라는 token이 나오려고 하면 [0,0,1]과 [0,1,0]사이의 모든 값은 존재하지 않는다는게 문제라는데..단어가 가지고 있는 공간은 총 3가지 뿐이기 때문.
        • 그런데 NLG가 아닌 일반적인 RNNLM을 이용한 task을 할 때 생기는 문제점은 아닌가?
          • 어차피 LM의 학습을 통하여 적용하는 task의 output space는 continuous하므로 괜찮은 듯?
        • A major reason lies in that the discrete outputs from the generative model make it difficult to pass the gradient update from the discriminative model to the generative model.
          • 내 식대로 이해해 보자면 즉 GAN에서 G모델을 학습시, D가 진짜라고 믿게끔 학습하는 것인데, 만약 G에서 "I are good"란 문장이 나왔다고 하자.
          • D입장에서 이 놈이 진짜라고 믿게끔 해줘야하니까 "I am good"로 문장이 생성되도록 D에서 G에게 어떠한 신호를 보내줘야 한다.
          • 하지만 are -> am은 discrete하게 멀찍이 떨어져있는 vector이다. (discrete token space)
          • 따라서 이것에 대한 정보를 전달하기 힘들다는 문제점인 것 같다.
          • one-hot vector이 들어가서 one-hot vector (next token)이 예측되어야 하는 상황임. 그런데 output이 one-hot vector로 discrete token space이므로 실제 예측 function은 discrete한 것이 나와야 한다.
          • 따라서 이론적으로는 우리가 원하는 model G는 미분이 불가능할 것이다.
          • 하지만 우리는 그냥 continuous G 모델로 one-hot vector을 가지도록 학습을 할 수는 있으나 이러면 괴리감이 있기 때문에 제대로 학습이 안된다는 점 같다.
        • (추가부분) 언어 생성 모델을 짜고 직접 학습하면서 이부분을 이해했는데...
          • decoder(generator)을 통하여 vocab에 있는 token들의 확률들이 나오게 된다.
          • 이 확률을 가지고 argmax을 취해서 token을 결정해주고 이를 다시 word2vec을 취한후 D(discriminator)를 통과시켜 이용하는게 일반적이다.
          • 여기서 argmax란 discrete한 function이기 때문에 gradient가 안흐르게 된다.
          • 따라서 argmax을 안취하고 output 자체가 discrete한 space을 도출해야하는데 이게 힘들다는 것이고 이 부분이 이미지의 GAN과 다르다는 것이다.
          • soft-embedding을 이용하여 해결할 수도 있는데 soft-embedding이란 그냥 확률 값을 이용하여 word2vec을 통과시키는 것이다. (어떻게 보면 weighted word2vec 같은 것)
          • 즉 soft-embedding을 이용하면 continuous하기 때문에 학습이 가능하다.
          • 이 논문에서는 policy gradient처럼 D의 결과를 reward식으로 이용하여 학습하는 것을 설명한다.
          • input으로 word2vec 혹은 더 나아가서 BERT embedding을 넣으려고 해도 비슷한 문제점이 있다는 것을 생각해볼 수 있음!
        • 그래서 Professor forcing 논문이 이러한 문제점을 지적하는거 같은데..
      • 어찌되었든 우리가 학습할 모델이 결국 discrete space이면, BP 미분식으로 학습하는 continuous 식으로는 결국 제대로 된 학습이 안되는 것!?
    2. GAN에서 D model은 오직 entire sequence, 즉 문장이 다 생성되고나야 score/loss을 매길 수 있다는 점이다. 즉 문장이 잘 생성되고 있는지에 대해 partial sequence에서 feedback을 줄만한 score/loss가 없다는 점 같다.
      • 지금까지 생각으론, RNNLM이 Generator 모델로 각 token에 대한 loss등을 설정하며 문장 생성을 할 수 있다. 하지만 이 때 exposure bias가 발생할 것이고 이를 해결하기 위해서 문장 전체에 대한 loss를 주는 과정이 필요하고 GAN을 이용해보자!
      • 전체적인 문장에 대한 loss는 Discriminator 모델을 만들어 학습을 시킬 것이다. 이 때 D loss는 전체 문장에 대한 것이고 실제, 문장이 잘 생성되는지에 대한 feedback을 줄 중간 partial sequence에 대한 어떠한 설정이 없다는 점 같다. (즉 이미지하고 조금 다른게, 이미지는 마지막에 잘 생성되었는지 보면 되는데, 언어는 생성이 time series식으로 하나씩 token이 생성되는데 이에 대한 고려가 일반적인 GAN에서는 안한다는 말인 듯)
  • 이 논문에서는 bahdanau(actor-critic 방법)을 기반으로 한 방식을 따른다.
  • RL Agent: 생성 모델
  • state: generated token
  • action: next token to be generated
  • reward: Feedback given to guide G by D on evaluating the generated sequence
  • 방법은 Monte Carlo (MC) search 으로 state-action value 예측을 한다.

2. Related Work

  • 생략

3. Sequence Generative Adversarial Nets

  • 이 부분부터는 강화학습의 REINFORCEMENT 알고리즘을 이해하고 있어야 논문이 이해가 된다.
    • 강화학습이 사실 개인적으로 쉽지 않다고 생각하는 분야라 완벽히 마스터하기는 ....
    • 하지만 강화학습이 메인이 아니라 이 논문처럼 방법론으로 사용하는 정도라면 데이비드 실바 강의(한글로는 팡요랩)를 이해하면 되는 정도라고 생각함!
  • Generator model -parameterized 이고 sequence 을 생성하게 된다.
  • Timestep t에서 state 을 말하며 current provided tokens을 말한다.
    • 즉 t-1까지의 생성된 token까지 모아둔 sequence을 state라고 생각
  • action a는 next token  to select을 말한다.
    • 즉 t time에 어떤 token을 선택할지를 말하는 것이다.
  • Policy model 은 stochastic이고 state transition은 deterministic이다.
    • 이 말은 1:t-1까지 tokens까지 생성되어 있고, state 에서 를 생성한다는 것은 next state  은 로부터 state transition 확률이 1이다. 다른 state s''와 같은 것으로 갈 state transition 확률은 0이 된다.
    • 쉽게 생각해서 state를 표현하는 것은 유일하므로(token들의 집합이 다르면 다른 state이니까) state transition이 deterministic 할 수 밖에 없다. 
  • -parameterized discriminateive model 은 가 향상되도록 guidance를 주는 역할을 한다.
    • 는 가 real 문장인지 아닌지에 대한 확률을 뱉어주는데, 이게 G한테 더 잘 만드라는 것을 말해준 다는 뜻인 듯
  • Figure 1을 보면, 오른쪽이 SeqGAN에 대한 흐름인데 state에서 action을 해가며 끝까지 간다. 
    • 다음 챕터에서 나오는지 모르겠지만 여기서 언제 문장이 끝나는 지에 대한 언급은 없다.
  • 즉 MC search로 문장들을 생성한 후, D로부터 reward를 받게 되는데 이 reward가 의미하는 것은 문장이 제대로 생성 됐는지 안됐는지에 대한 확률이라 보면 된다.
  • 이 reward를 이용하여 policy gradient 방법으로 학습하게 되는 방식이다.

3.1 SeqGAN via Policy Gradient

  • 강화학습에서 policy gradient에서 objective function을 episodic environment 환경에서 을 생각해볼 수 있다.
  • 이를 기반으로 여기서도 objective function을 설정할 수 있다. G가 생성 모델이고 G는 다음의 obj func을 maximize을 하도록 학습을 시키는 개념이다.
    • Q는 reward의 축적 값인데 이 reward는 discriminator  에서 부터 발생한다.
    • 즉 문장이 완성된 sequence를 D가 진짜 문장인지 아닌 지로 binary classification 하는데, 이 때 진짜일 확률 값이 reward라고 생각하면 된다.
    • Q(action-value) 함수는 REINFORCEMENT 알고리즘을 이용하여서 학습하였다.
    • 즉 다시 말하면 
      (2) 식이 Q값을 가지고 이는 문장이 완성된 상태에서 발생하는 reward값이다.
    • 그렇다면 중간 과정에서의 reward는??
      • 이 식대로라면 중간 과정의 reward는 0으로 처리가 된다.
      • 즉 final step만 고려한다면, previous token의 적합성만 고려하게 되는데, 이 뿐 아니라 중간 중간 생성과정에서 future 결과 또한 고려하는 것을 넣어줘야 한다.
    • 이렇게 Q값을 정하기 위해서 Monte Carlo search with roll-out policy를 적용한다.
      • Variance을 줄이고 Q의 정확한 평가를 높이기 위하여 rollout 사용함
      • 쉽게 생각하면 현 상태, 현 policy로 미래를 탐색한 다음, 이로부터 reward로 받겠다는 것이다.
      • MC로 탐색하기에 시간, 비용등 때문에 roll-out policy(간단한)을 적용하는 것이다.
      • rollout 참고!
    • 따라서 다음의 식(3) 같이 N-time MC search를 표현할 수 있다.
      • 이는 y1~yt가 주어지고, yt+1~yT까지를 roll-out policy 를 따라서 N번 sample 하겠다는 것이다.
      • 는 generator을 결국 의미하는 것이고 simple version을 사용하는데 뒤에서 어떤 모델인지 나옴.
    • 따라서 이제 수정된 Q value는 다음 식(4) 와 같다.
      • 이것 또한 중간 과정에서의 reward는 없이 완성된 문장을 D로 판별한 reward가 곧 Q이지만, N-time MC search로 중간 과정을 고려한 것이라고 판단 됨.
    • 이렇게 을 사용한 reward의 장점은, generative model을 반복적으로 향상 시킬 수 있다는 것이다. 또한 G 모델이 제대로 된 문장을 생성하고 나서는 D모델을 re-train을 다음과 같이 하는 것이 좋다.(즉 GAN 방식을 적용할 수 있다)
      • 식(5)는 D에 해당하는 파라미터를 업데이트하는 것임.
  • 최종적으로 obj function (1)은 다음과 같이 (6)식으로 표현할 수 있다.
    • 맨 앞의 sigma 부분은 MC search를 담당
    • expect value는 정의 + condition Y 설정
    • 뒷 괄호 부분은 매 sample에 대한 obj function이라고 보면 됨.
  • 이 (6)식은 밑의 (7)처럼 유도가 된다.
    • 유도 과정은 다음의 부록을 따른다.
    • 이 부분은 강화 학습을 알면 어렵지 않은 부분이니 하나하나 자세히 따라가서 이해될 것임
  • 따라서 식 (8)로 업데이트를 하면 됨. alpha는 learning rate이다.
    • GD 알고리즘은 Adam과 RMSprop 등을 썼다고 함.
  • 이러한 과정을 다합쳐서...전체적인 학습은 다음과 같다.
    • 여기서 초기화는 MLE로 G를 학습하고 이 G에대한 D를 학습하는 식으로 pre-training을 한다.
    • 이렇게 pre-training 하는 것이 효과적임을 실험적으로 알았다고 함.
    • Dataset S: positive, Generate data: negeative sample로 사용했다고 하고, D를 학습할 때, balance를 고려하여 d-steps = #generative negative samples = #positive examples 가 되도록 설정하였다고 함. (d-step마다 negative sample을 생성하고 positive sample을 하나 sampling 하는 듯?)
    • 여기서 generator policy를 roll-out policy에 대입을 하는데 이는 큰 의미는 없는 것 같음.
    • 단지 학습 시 generator로 중간까지 생성된 문장을 고정하고 rollout으로 그 뒤의 문장을 생성 후 reward를 받는데 사용하는 식임

3.2 The Generative Model for Sequences

  • Generative model로 LSTM을 사용함.(GRU도 사용가능 함)
  • softmax output layer을 bias vector c와 weight matrix V를 이용하여 추가함.

3.3 The Discriminative Model for Sequences

  • Text classification 모델로는 CNN / RCNN / RNN 등 여러 방법이 있지만 여기서는 Zhang and LeCUN 2015 방법을 썼다고 한다.
  • CNN 구조를 가지며, appendix나 논문을 찾으면 자세히 설명이 나오는데 (Kim 방법이랑 똑같은거 같은데..) 아주 간단히만 적어보면
    1. token embedding 된 것을 concat을 시킨다.
    2. convolution을 태운다. filter size는 여러 개로
    3. 나온 값을 max over time pooling을 시킨다.
    4. highway network을 태운다.
    5. 마지막으로 FC을 통하여 cross entropy 방식으로 학습.

4. Synthetic Data Experiments

  • 실험을 하기 위해서 synthetic data을 만들었다고 함.
  • Oracle로 불리는 랜덤 초기화 LSTM을 true model로 사용했다고 함.
    • 즉 진짜 real data로 실험한다는 것이 아니고 "랜덤 초기화 LSTM"을 oracle로 생각하고 이게 만드는 데이터가 진짜 data라고 생각하자.
    • 그러면 이 데이터로 seqGAN을 학습하여서 seqGAN이 oracle을 잘 모방한다면 제대로 학습이 된다는 의미로 볼 수 있다.

4.1 Evaluation Metric

  • oracle을 사용하면 두 가지 이점이 있다.
    1. 학습 데이터세트를 제공함.(데이터를 만들 필요가 없다는 것임)
    2. 정확한 성능을 측정할 수 있다.
  • 식은 다음과 같다.
  1. 내가 이해한 바로는 학습된 generative model G가 문장을 생성한다.
  2. 생성: 예시 문장1) 나는 논문 보느라 목이 아프다 / 예시 문장2) 나는 책을 읽어서 배가 고프다.
  3. 이것을 oracle(=사람한테) "나는"(t=1)이란 token을 보여줘서 그 다음 token을 만들라고 한다. "논문"이 나올 확률을 NLL로 구한다.
  4. "나는 논문"을 보여주고(t=2) "보느라"가 나올 NLL을 구한다.
  5. "나는 논문 보느라"을 보여주고(t-3) "목이"가 나올 NLL을 구한다.
  6. "나는 논문 보느라 목이"을 보여주고 "아프다"가 나올 NLL을 구한다.
  7. 이런식으로 나온 NLL을 다 더한다.
  8. 상대적으로 여기서 예시 문장1은 NLL이 낮게 나올 것이고 예시 문장2는 NLL이 높게 나올 것이다.
  9. 이렇게 예시 문장들에 대한 expect value을 구한다.(즉 다 더해서 개수를 나누는 방식으로 할 듯)
  • 여기서 이해를 위해 oracle을 사람으로 생각하고 real data를 가지고 seqGAN을 학습시켰다고 생각하는 것이랑 개념이 비슷한 것 같다.
  • 실제로는 랜덤 초기화 LSTM이 oracle이라고 한다. 랜덤 LSTM으로 next token이 생성되는 것이 real data라고 가정한다는 것이다.
  • 이 데이터를 가지고 seqGAN을 학습시키겠다는 것이다.
  • 그러면 oracle이 real distribution이고 학습된 seqGAN이 oracle을 잘 모방하면 학습모델이 real distribution을 따른 다는 것이고 수식상 NLL_oracle이 작아진 다는 것이다.

4.2 Training Setting

  • Normal distribution N(0,1)을 따르는 LSTM을 oracle로 real data을 만든다.
  • 10,000 sequences of length 20으로 training set S을 만든다.
  • Dataset S는 label을 1로 generated examples은 label 0으로 구성한다.
  • 1. Random token generation, 2. MLE trained LSTM, 3. SS, 4. PG-BLEU와 비교를 하였다.

4.3 Results

4.4 Discussion

  • D-step과 G-step에 따라 SeqGAN의 안정성이 결정이 된다.
  • 실험적으로 G-step이 크면 빠르게 수렴은 되는데 D가 제대로 수렴이 안되어 있기 때문에 나중에 학습이 잘 안되는 불안전성을 가지는 것을 알 수 있었다.
  • 실험에서 알 수 있듯이 d-step이 g-step 이상이어야 seqGAN이 안정적이다.
    • 생각해보면 D가 곧 reward인데 reward 자체가 불완전하면 G또한 제대로 학습이 안될 거 같긴 함..
    • 원래 영상 GAN에서도 이러한 문제가 있을 듯? 일반적으로 GAN은 D를 더 많이 학습하는 거 같음.

5. Real-world Scenarios

  • 그렇다면 진짜 데이터세트로 실험해보자.(poem composition, speech language generation, music generation)

5.1 Text Generation

  • Chinese poems와 Barack obama political speeches 두 개에 대해 실험함.
  • Generation texts와 human-created texts 사이의 BLEU score로 유사성을 구함.
  • BLEU score은 원래 번역 품질을 위해 만들어진 방법이라고 함.
  • 중국 시는 n-gram(2)인 BLEU-2로 평가를 하였는데 이는 중국 시가 한 개 혹은 두 개의 characters로 구성이 되어 있다고 함.
  • 오바마 연설문은 비슷한 이유로 BLEU-3과 BLEU-4로 평가했다고 함.
  • 중국 시 같은 경우, MLE와 SeqGAN, real data 각각 20개씩 총 60개의 시를 70명의 중국 시 전문가에 대해 평가를 받았다.
    • 진짜 시 같으면 +1, 아니면 0점을 줘서 scoring 하였다.
    • 오바마 연설문에 대해서는 왜 안했는지 아쉽.. 아마 지네들이 보기에 제대로 생성으로 보기에는 무리가 있어서 그랬을 듯.

5.2 Music Generation

  • Nottingham datasets을 사용하였다고 함.
  • BLEU와 MSE로 평가를 함. (위에 Table 4)
  • 이걸로 생성된 음악이 어떨지 개인적으로 궁금..
  • 사실 음악 생성은 단지 패턴을 기억해서 섞어둔 정도가 아닐까 싶은데

6. Conlcusion

  • GAN으로 discrete token의 시퀀스를 생성한 첫 번째 시도이다.
  • Oracle evaluation mechanism 사용
  • MCTS와 value network을 사용해서 발전 가능.(알파고 논문 처럼 말하는 듯)
  • 추가적으로 appendix에 pre-training의 유무에 따른  실험이 있는데 pre-training을 안하면 좀 처럼 성능이 좋아지지 않는다. 즉 MLE로 pre-training 하는 것이 중요한 팁인듯!
Reference

댓글