Enginius/Python&TensorFlow

TF에서 gradient 정보를 바로 활용하기

해리s 2020. 3. 21. 15:13

만약 TF computational graph에 custom function과 custom gradient를 넣고 싶으면 어떻게 해야할까?

아래 gist에 해당 내용을 정리해보았다.

https://gist.github.com/sjchoi86/01d8957cfdeb39d55e7dd42e8836b2ab

 

결국 TF 문법이 어려웠던 것인데,

 

  • 기본적인 tf.py_func은 gradient가 None이 되어서, 만약 forward path만 신경을 쓰고 싶을 때 사용하면 편한 것 같아요. 특히나 graph에 있는 tf Tensor를 sess.run 등으로 뺴오지 않아고 py_func으로 처리를 하면 해당 함수 안에서는 ndarray가 되거든요.
  • 만약에 gradient도 custom으로 주고 싶다면 tf.RegisterGradient를 해주고 (unique한 이름으로), gradient_override_map으로 gradient를 알려줘야해요. 아래 gist에서는 grad_wrapper가 그 역할을 해줍니다. 기본적으로 grad_wrapper의 op와 grads는 tf Tensor가 되므로 여기 안에서 tf.py_func을 써서 numpy로 정의한 custom gradient function을 call하게 됩니다. (custom_func_derivatives)

아직도 어렵다.