KL divergence

KL divergence는 사실 엄청 많이들 접해봤을 것이라고 생각한다. 
기본적으로 ML에서 두 확률 분포의 차이를 계산해야하는 task들에서 많이 나온다. 
대표적으로 GAN 등에서 개인적으로는 많이 봤던 것 같다. 
팡요랩에서 다룬 부분이 이때 까지 보았던 설명자료중 가장 잘 정리해 둔 것 같아서 그것을 토대로 간단히 정리해 보려고 함. (팡요랩은 PRML 책을 기반으로 정리하였고 시간이 되면 동영상을 직접 보는거 추천)
  • 예시로 설명을 하자.
  • 1번상황) A 사람이 B 사람에게 오늘의 해가 뜨는 방향을 알려준다.
    • (실제로는 100% 동쪽에서 뜨겠지만) 문제 설정을 0.99999 확률로 동쪽, 0.00001 확률로 서쪽에서 뜬다고 하자.
    • 이 때 B 사람은 동쪽에서 뜬다는 말을 듣고 별로 놀라지 않을 것이다.(당연하니까)
    • 즉 이 때 놀란 정도를 "정보량"이라고 볼 수 있다.
    • 수식으로 정리하면 확률 변수 x는 east or west
    • p(x)는 x의 확률 분포
    • h는 p(x)에 대한 정보량으로써 h=f(p)라는 f 함수를 따름.
    • f는 이때 단조 감소 함수일 것이다.
  • 2번상황) A 사람이 B 사람에게 오늘의 날씨 정보도 같이 알려준다.
    • y는 비가 오는지 안오는지에 대한 확률 변수
    • 이 때 x,y는 독립변수라고 생각
    • 우리는 h(x,y) = h(x) + h(y) (정보량은 단순 합으로 표현되길 원함)
    • p(x,y) = p(x)p(y) (독립변수라고 가정했기에)
    • 이를 만족시키는 함수 f를 찾자!
    • 즉 f(p(x,y))=f(p(x))+f(p(y)), f(p(x,y))=f(p(x)p(y)) 이기 때문에 f(p(x)p(y)) = f(p(x))+f(p(y))을 만족시켜야함!
    • 따라서 f는 log함수로 정의 가능!
    •  로 표현 가능(log의 밑은 2 외에도 아무거나 상관 없음)
  • 1번 상황으로 돌아가서 평균적인 정보량은 어떻게 될까?
    • h(east)=-log2(p(east))=0.000014, h(west)=-log2(p(west))=16.609
    • 즉 0.00018이 평균 정보량(기댓값)이 된다.
    • 이 값을 entropy라고 부른다.
    • Entropy 용어는 열역학에서(고등학생 화학 수업을 떠올리자..) 자유도를 의미하는 말로 등장하는데 그것과 같은 개념이라고 하는 듯
    • 용어 그대로 entropy가 크면 자유도가 큰 것 즉 정보량이 큼을 의미한다.
  • 3번상황) 주사위 예시로 정팔면체인데 8개면이 나올확률이 (1/2, 1/4, 1/8, 1/16, 1/64, 1/64, 1/64, 1/64) 라고 하면 이때의 entropy = 2가 나온다.
    • 이때 이 8개의 면은 0, 10,, 110, 1110, 111100, 111101, 111110, 111111로 encoding할 수 있다.
    • average code length=2 가나온다.(단순계산)
    • 즉 여기서 말하고자 하는 것은 entropy가 최소한으로 정보를 encoding 했을 때의 평균 길이와 같다는 것이다. 
    • 따라서 entorpy는 average code length의 under bound라고 한다.(증명은 되어있다고 함)
    • 즉 결론은 average code length 개념으로 접근할 수 있다고 함.
  • Entropy 특징
  • 4번상황) 사면체에서 각 사면체가 나올 확률은 (1/4, 1/4, 1/4, 1/4) = p(x)
    • 하지만 우리가 학습한 모델(q=(x))은 (1/2, 1/4, 1/8, 1/8)로 잘못 예측
    • 이 때 모델의 평균 entropy==-1/4x(log2(1/2)+log2(1/4)+log2(1/8)+log2(1/8))=2.25 (이 값이 cross entropy임. 밑에 다시 언급)
    • 하지만 실제 평균 entropy= =-4x1/4xlog2(1/4)=2
    • 이 둘의 차이는 2.25-2=0.25 만큼 추가적인 비용이 발생했다고 본다!
  • KL divergence의 특징
    • KL은 정확한 거리의 개념은 아님!
    • 여기서 두 번째 증명은 KL(p|q)에서 미분이용해서 p=q일때 최소임을 보이고 두 번 미분해서 convex임을 보여주면 될 것 같음.
  • Cross-entropy
    • Cross-entropy는 KL divergence의 앞 쪽 term을 의미
      • 다시 말하면 KL divergence = cross entropy - entropy
    • q를 p에 가깝게 학습할 때, 뒤 쪽 term은 의미가 없다.
  • 결론
    • 기본적으로 entropy는 정보량의 기댓값을 의미함. (정보량은 -log로 정의)
    • KL divergence 개념은 정보량을 진짜 data 분포의 정보량과 모델 분포의 정보량의 차이임.
    • 즉 KL divergence를 작게 만드는 것은 진짜 data 확률 분포에 모델의 분포가 가까워지는 것을 의미하기 때문에 classification task에서 적용함
      • 하지만 KL에서 뒷 term은 p에 대해 고정이기 때문에 결국엔 Cross-entropy loss을 씀!
      • 따라서 딥러닝에서 cross-entorpy loss을 쓰는 것은 사실 KL divergence loss랑 같은 개념임.
        • 그렇다면 (전체 term을 포함한) KL divergence는 언제 쓰나?
        • p가 바뀌는 task에서 쓴다고 하는데 GAN에서도 나오고 다른 강화학습에서 나온다고 함..
      • regression 문제에서는 출력이 0~1인 확률이 아니기 때문에 L2, L1 loss등을 쓴다.
      • 근데 classification에서 MSE L2 loss을 쓰면 안되나?
        • One class와 multi class 나눠서 생각해보자
        • One class인 경우, 마지막 출력이 0~1인 확률이 나와야 하므로 logit에 sigmoid을 취해야 한다.
          • 그렇게 되면 BP을 할 때 수식적으로 학습이 잘 안되는 문제가 있음.
          • sigmoid 때문에 미분하면 실제 출력이 1이 나와야 할 때, 0이 나온 경우 학습이 잘 안이뤄짐(saturation 영역에서 BP가 잘 전달이 안되는데 자세한 건 딥러닝 북 같은 거 참조하거나 직접 계산해보면 됨)
        • Multi class인 경우, 마지막 출력을 softmax로 0~1로 만들어 준다.
          • 이 경우에도 softmax를 미분하면 z(1-z) 식으로 미분 값이 나오기 때문에 비슷한 현상이 발생한다. 
        • 이것 외에도 학습 시, error에 대한 정의가 잘 되어 있어야 하는데 cross entropy가 정보량에 대한 전달을 잘 해준다는 점에서 일반적으로 많이 쓰인다.
        • 하지만 MSE를 쓴다고 학습이 안되는 건 아니다. softmax + MSE로 학습해도 MNIST 경우 잘만되었던 기억이..
        • 추가적으로 multilabel classification 경우는, softmax로 처리할 수 없기 때문에 sigmoid_cross_entropy를 씀
    • 또한 P, Q를 바꿔서 사용하는 것은 어떨까?
      • 일반적인 위에서 언급한 것은 Forward KL이라고 부르고 P,Q 순서를 바꿔서 loss 정의하는 것은 Reverse KL라고 부른다고 한다.
      • Reverse KL이 유리할 때도 존재하는데.. 이는 Ref2 영상을 참고하시길!

Reference

댓글