본문 바로가기

Enginius/Machine Learning

Learning Wake-Sleep Recurrent Attention Models

따끈따끈한 논문이다. http://arxiv.org/pdf/1509.06812v1.pdf 

Abstract


CNN이 좋긴 하지만 이미지의 전 영역에서 연산을 해야한다는 단점이 있다. 그래서 randomness를 섞은 stochastic attention-based model이 test 단계에서 시간을 줄여준다는 장점이 있다. 하지만 학습이 어렵다는 단점이 있다. 이 논문에서는 training deep generative model에서 제안된 내용들을 이용해서 wake-sleep recurrent attention model을 제안한다. Posterior inference 성능을 향상시켰고, stochastic gradient의 variability를 줄였다. 


Introduction


CNN은 성능이 깡패다. 하지만 고화질 이미지에 시간이 오래 걸린다. 왜냐하면 전 영역에 테스트를 해야하기 때문이다. 그래서 visual-attention model이 나왔다. 일반적으로 이런 attention model은 두 가지로 나뉠 수 있다: soft and hard attention based models. 


Soft attention based model은 전 이미지 영역에서 weighted average를 통해서 feature를 뽑고, weight를 saliency map을 통해서 구한다. 이에 반해서 hard attention model은 몇 개의 discrete glimpse location들을 고른다. 따라서 soft attention model은 연산이 오래걸린다. 우리는 연산을 줄이기 위해선 hard attention model이 중요하다고 생각한다. 물론 hard model은 학습이 어렵다는 단점이 있다. 


Stochastic hard attention model을 학습시키는 것은 어렵다. 왜냐하면 gradient가 intractable posterior expectation을 포함하기 때문이다. (당연하다. selection은 미분가능하지 않다.) 그래서 이 연구에선 Wake-Sleep Recurrent Attention Model (WS-RAM)을 제안한다. 이 알고리즘은 stochastic recurrent attention model을 학습하는데 사용된다. 


학습 과정에서 WS-RAM은 posterior expectation을 importance sampling을 통해서 구하고, proposal distribution은 inference network를 통해서 구한다. Prediction network와는 다르게 inference network는 물체의 label 정보에 접근이 가능하고, 이는 glimpse location을 더 잘 고르게해준다. 


이 논문의 주된 contribution은 stochastic attention model을 학습시키는 새로운 방법을 제안하고, 기존의 variational inference를 이용한 방법과 비교했다. 두 번째는 novel control variate technique for gradient estimation을 제안한다. 마지막으로는 우리의 stochastic attention model이 translated되있고, scale된 MNIST digit을 구분할 수 있고, 이미지 캡션을 만들어낼 수 있다는 것을 보였다. 


Wake-Sleep Recurrent Attention Model


이 논문의 가장 중요한 부분인 Wake-Sleep Recurrent Attention Model (WS-RAM)을 설명하겠다. 


이미지 '$I$' 가 주어지면 네트워크는 먼저 sequence of glimpses '$\mathbf{a}=(a_1, ..., a_N)$'을 고른다. 그리고 각 glimpse 후에 observation '$\mathbf{x}_n$'을 '$g(a_n, I)$'의 mapping을 통해서 구한다. 이 매핑은 예를 들어 해당 위치에서 이미지 패치를 뽑는 것을 의미한다 첫 번째 glimpse는 입력의 low-resolution version을 이용해서 구하고, subsequent glimpses는 이전 glimpse에서 구해진 정보를 이용해서 구한다. '$p(a_n|a_{1:n-1}m \mathbf{I}, \theta)$' 를 이용해서 확률적으로 구해지고, '$\theta$'는 네트워크의 파라미터를 의미한다. 이러한 방식은 이미지 전 영역을 보는 soft attention과 대비된다. 마지막 glimpse 후에 네트워크는 타겟 '$y$'의 분포 '$p(y|\mathbf{a}, \mathbf{I}, \theta)$'를 구한다. 예를 들어, 타겟은 이미지 캡션이나 이미지 category를 나타낸다. 


그림 1에 나타나있듯이, attention network의 core는 두 계층의 recurrent network로 우리는 "prediction network"라고 부른다. 각 시간의 출력은 action으로 다음 단계의 input을 계산하는데 사용된다. 가장 간단하게 생각할 수 있는 것은 이미지 패치의 위치를 옮기는 것이다. (최근에 권인소 교수님 연구실에서 나온 CNN을 이용한 attention model과 상당히 비슷하다. 사실 attention model이 다 비슷비슷하다.) 저화질의 input image가 네트워크에 주어지면 sequence of glimpses가 얻어지고, 마지막 glimpse를 통해서 class label을 구한다. 특히 저화질의 입력이 두 번째 layer로 주어지고, class label prediction은 첫 번째 layer를 통해서 구해진다. 이는 low-resolution layer와 class prediction 사이의 dependency를 끊는다. (좋은 아이디어다. 저화질 이미지는 대략적인 윤곽을 주고, 실제 action은 고화질의 이미지에서 크롭된 데이터를 사용하려는 것 같다.) 


Prediction network의 윗단에는 inference network가 있다. 이는 입력으로 class label과 attention network의 가장 상위 representation을 받는다. 이 네트워크는 posterior distribution의 근사 '$q(a_{n+1}|y, a_{1:n}, \mathbf{I}, \eta)$'를 의미하고, '$\eta$'로 parametrize되어있고, image category에 condition되어있다. 이는 posterior sampler로 사용되며, attention network의 선생 역할을 한다


Stochastic attention model의 다른 benefit은 mapping '$g$'가 이미지의 small region으로 localize될 수 있다는 점이고, 이는 전체 알고리즘을 매우 효율적으로 만들어준다. 게다가 '$g$'는 미분가능할 필요가 없어서 많은 연산을 가능하게 한다. 


이 뒤는 학습에 대한 얘기이다. Detail은 생략하겠다. 


Over-egging the pudding


이 논문의 가장 주된 contribution은 inference network를 통해서 proposal distribution을 만들고, 이를 통해 학습을 효과적으로 할 수 있게 하는 것이다. (importance sampling) 즉 학습을 효과적으로 하는 방법론인 것 같다.