[Tensorflow]Tensorflow 2.0建立模型的方法

Original article can be found here (source): Deep Learning on Medium

[Tensorflow]Tensorflow 2.0建立模型的方法

在Tensorflow 2.0的大改版後,寫Tensorflow像是變了一個新的框架。雖然TF2大致上支援TF1的語法,但從網路上公開的code中可以發現,以往在TF1中建立graph的方法都漸漸改變了。我認為這一切的原因都來自整個variable命名與操作的設計變化,因為Python原生function與variable與Tensorflow的融合,我們不得不使用一些新的方法來管理Tensorflow的運作。

而這個新的方法,就是Keras物件。

這篇文章會介紹Keras物件在TF2中的地位,以及如何在TF2中透過tf.GradientTape建立可訓練的模型。

關於Tensorflow 2.0的大略改動與方向,可以參考上篇文章。

不甘屈就選項的Keras

稍微看一下時下流行的TF2開源程式,不難發現Keras這個關鍵字頻繁地出現在不同專案上。雖然大部分的專案不見得完整的使用了Keras的訓練流程,但大多都會使用了Keras.modelKeras.layers.Layer

變數的管理

這一切的改變,主要來自Tensorflow 2.0在variable命名與操作的設計變化。

Keep track of your variables! If you lose track of a tf.Variable, it gets garbage collected. (from Tensorflow official website)

白話點說,請自行保管好所有的tf.Variable,弄丟了恕不奉還。在這樣的劇本中,大部分的開發者只好摸著鼻子,乖乖地遵從官方建議的方法:使用Keras.modelKeras.layers.Layer來管理變數。

Keras物件的三大法寶

我認為Keras Functional API裡有三大法寶,是目前寫TF2的人會頻繁使用到的物件:keras.modelkeras.layers.Layerkeras.Input

Layer

顧名思義keras.layers.Layer就是一個客製化的類神經layer,沒什麼特別的,一般來說寫法與Pytorch感覺很像,都是在__init__定義變數或運作元,在call裡才定義連接。

weights: 6
trainable weights: 6

這個簡單的例子建立了一個客製化的Keras.layers.Layer,說穿了就是三層全連接層。值得注意的是這邊的weighstrainable_weights都可以直接取用(長度為6是三個Dense的weights跟bias),這就是為什麼會使用Keras物件來管理變數的原因。

Model & Input

keras.Model則是將一個或多個運算元、Keras.layers.Layer堆疊起來管理,並且提供統一的進出口。透過自行連接好的個個運算元,只要將輸入與輸出定義好,就可以透過keras.Model打包。在使用時,輸入則是透過keras.Input實現。keras.Input是有形狀、並支援None長度的接口,實際訓練時再透過這個接口送資料進來,其實跟placeholder的感覺差不多。

Input: tf.Tensor( 
[[[0 1 2 3 4]]
[[5 6 7 8 9]]], shape=(2, 1, 5), dtype=int32)
Output: tf.Tensor(
[[1.0517771] [3.3455203]], shape=(2, 1), dtype=float32)

keras.Model一樣支援trainable_variables等方法,一樣可以達到控管變數的目的。而keras.Input實際操作時可以支援keyword來實現multiple input,使用上感覺等價於TF1在session.runfeed_dict

現在常見的寫法

現在TF2中最常見的寫法是將重複性的原件透過自定義的subclassed layer寫好後,再用keras.Model打包。最後使用tf.GradientTape訓練。在inference時直接調用打包好的keras.Model

重複使用一下上面範例中的MLPBlock,這邊在打包成一個MLPModel。唯一的差別是多了inputs跟outputs的控管。訓練時在外面呼叫tf.GradientTape紀錄梯度,用GradientTape.gradient依照目標變數取出梯度後,再送到optimizer優化目標變數。

losses: tf.Tensor(1.1403029, shape=(), dtype=float32)
losses: tf.Tensor(0.06633704, shape=(), dtype=float32)
losses: tf.Tensor(0.0027843604, shape=(), dtype=float32)
losses: tf.Tensor(8.474709e-05, shape=(), dtype=float32)
losses: tf.Tensor(2.3672121e-06, shape=(), dtype=float32)

小結論

以上就是目前最常見、方便的Tensorflow 2.0建立模型的方法。當然還有走原本Model.compile的完整Keras functional API,不過實際上受限制較多,還是比較推薦用自由度高的tf.GradientTape

從以上的範例中不難發現,Keras物件跟Tensorflow 2.0是緊密結合的,而且寫法、變化多端,下一篇文章會介紹一些進階的Keras物件混搭實作,以及分析Keras functional API與subclassed layer兩者間的優缺點。

對於TF2寫法想要認真進修的,我相當推薦看一下YOLO v3的實作與我同事Peter寫的ESRGAN