[機器學習 ML NOTE]Generative Adversarial Network, GAN 生成對抗網路



我覺得GAN是近期Deep learning裡面最有趣也是最厲害的一個技術,他可以自動生成training data,甚至data跟真實的資料非常的相似,甚至可以做出圖像特徵的轉變(畫風改變,人臉改變等技術),這些生成的資料便可以拿去training model,也可以生成一些自己的所想要的資料,接下來我就來介紹一下GAN的原理,下面我也會用Tensorflow把GAN給實作出來,生成Mnist手寫資料!!

Generative Adversarial Network 生成對抗網路

“GAN!!這也太厲害了吧!!!” GAN的出來讓我們可以很大聲的說髒話了(誤

GAN是2014年的一個大神 Ian Goodfellow 提出來的方法,我用簡單一點的話來表達什麼是GAN,在GAN組織裡面中有二個角色,一個是專門偽造假名畫來去賣的G先生,一個是專門鑑定此符畫是不是真畫的D先生,D先生會從G先生那邊拿到假畫來辨斷真假,G先生則是利用D先生的鑑定來改良自己製造假畫的技術,G先生跟D先生互相共同合作,GAN! 這跟本要大賺了。

我們來看下面的圖,你能相信這些寢室圖都是GAN所製造出來的假圖嗎?我是看不太出來啦,GAN這個組織跟本可怕!!

圖片來自 “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks” https://arxiv.org/abs/1511.06434v2

我們可以從前面的話整理出GAN的簡單Flow架構

GAN Basic Flow Chart

上面的Flow是一個最基本的GAN的Flow,可以看出GAN中有二個Neural Network需要去Train,接下來我就來介紹Discriminator跟Generator這二個神經網路應該要怎麼去訓練。

Discriminator Network(鑑別器網路)

鑑別器網路簡單一點說明的就是,訓練出一個Neural Network可以分辨偽造出來的圖跟真實的圖,那要怎麼訓練這個網路呢?我自己畫了用以下的圖來理解。

Discriminator NN training

沒錯,就是很直觀的,我們把Generator出來的圖標記為0(fake image),然後把真實的圖標記為1,這樣的training data 丟進我們的Discriminator Network做訓練,這就是每一次Discriminator訓練的步驟了,接下來我們來看Generator怎麼做,然後再把二個連結再一起。

Generator Network (生成器網路)

生成器網路的概念也很簡單,就是要訓練出一個Neural Network可以讓Discriminator分辨出來的結果越接近真實(1)的結果越好,我畫了以下的圖來理解生成器網路的訓練方式。

Generator NN training

從以上的圖中就可以知道,我們可以把Generator+ Discriminator看成是一個大的Neural Network,假設生成器是前5層的Neural,鑑定器是後5層的Neural,然後這個10層的Neural Network 預估出來的值越接近1越好,但是這裡我們只update Generator 的weight,Discriminator的Weight要Keep,這樣才可以用更新後 Generator出來的假圖在Discriminator上的值越接近真實的結果,我們可以用一個結論來表示,其實就是更新生成器的參數讓Discriminator接近真實的結果。

GAN (Generative Adversarial Network)

我們可以從上面看出生成器跟鑑定器的參數要怎麼update,下面是論文中GAN的演算法

依照演算法上的Loss function,我統整了整個GAN的流程,希望可以更容易了解GAN的基本運作。

GAN Training Flow Chart

G_Loss就是我們把生成器產生的偽造圖輸入進鑑定器中的輸出跟1的loss

D_Loss就是我們把生成器產生的偽造圖輸入進鑑定器中的輸出跟0的loss+真實的圖輸入進鑑定器中的輸出跟1的loss。

整個GAN的架構大概就是這樣,接下來就來實作啦!!!!!!

MNIST 手寫數字 GAN TENSORFLOW 實作

下面的CODE就是我用MNIST資料來實作GAN

https://github.com/super13579/tensorflow-GAN-MNIST/blob/master/GAN_MNIST.pyscript

Generator (生成器)

為4層網路,都為fully connect layer,input為1*100的random noise data,這裡要注意的是最後一層的activate function是用tanh

Discriminator (鑑別器)

鑑別器也是4層網路,input為攤平的28*28的image data,也就是1*784的array,最後output layer出來的是鑑別圖片是真是否的預估值,每一層最後都會有個dropout防止overfitting的狀況 (想了解什麼是overfitting可以看我前面介紹的文章)

Save Image

生成16張假資料並save,就像這樣。

Loss function跟訓練器的設定

這裡我們需要Generator先去產生一個假圖片,這裡我們把real跟fake的資料分開進鑑別器(discriminator)去預估出是真是假的預估值,這裡要注意的是D_fake那邊必須要開reuse,不然會有error(我當時debug超久…),loss function就照論文上的演算法設定,前面都要加個負號,不加負號loss會出現nan的狀況…至於為什麼我就不太清楚了QQ

開始訓練

接下來就是寫main function來訓練啦,可以直接看code,很直觀

結果

我們來看一下訓練200 epoch的狀況,看得出來越後面的epoch出來的圖像有越來越像mnist手寫資料

我們來看一下200 epoch 的loss 變化

總結

GAN是一個應用很廣的model,這篇文章介紹的只是最基本的一個GAN架構,近幾年的GAN論文量爆增(嚇,也出現了很多的應用技術,我簡單介紹一下近期常用的應用以及一些GAN衍伸的MODEL

圖像到圖像的翻譯

二種不同Domain的圖像做出轉換的技術(這真的太厲害了啊!!之後一定要來研究一下!!)

Condition GAN

Cycle GAN

文字到圖像的翻譯

輸入文字描述就可以輸出一個圖片跟文章有相關的,這是GAN結合NLP的應用,也是Base on Condition GAN的技術

高解析度成像

其實就是給一張模糊的照片,利用GAN做出高解析度成像的動作。

GAN真的是一個很有趣的技術,我之後有時間的話想要多做一些GAN的實驗,把GAN結合CNN的DCGAN做出來,還有ConditionGAN的影像轉換,之後應該都會做一些不同的應用(用在非Mnist資料集上面)。

參考資料

NTU ADLXMLDS 課程(主要是看這個)

最簡單的GAN實現

一文看懂GAN

GAN paper

大家如果覺得我寫的還可以並且有幫助到你/妳的話,拜託給我一點掌聲吧,這樣我會更努力並更有動力的把自己所學給寫下來的!!

Source: Deep Learning on Medium