• このエントリーをはてなブックマークに追加

fast.ai Lesson10のテーマはGAN.

課題は、PyTorchで書かれたWGANを改善せよとのこと。さすがは、Cutting Edgeだ。GANもDCGANも通り越して、WGANに行ってしまう。

PyTorchもGANもよくわからない自分にはツライ。まずは、WGANの前にPyTorchとGANからはじめることにした。

まずは、GANの開祖である以下の論文に目を通した。

スポンサードリンク

PyTorch first inpression

軽くPyTorchのチュートリアルと fast.aiの Jupyter Notebookを眺めたあと、PyTorchに挑戦!numpyを扱うみたいで書きやすい。

PyTouchの特徴は以下のようだ。

Kerasだと、簡単に書くために細かい部分はライブラリに覆い隠されているけれども、PyTorchは多少なりともむき出しになっているので、細かいカスタマイズがしやすそうない印象を受けた。Kerasと PyTorch、両方使えるようになりたい。

GANを実装してみる

GANの実装は、Kerasバージョンがfast.aiから提供されている。

これをPyTorchに置き換える。パラメータを参考にしつつ、また 公式の DCGANチュートリアルの実装も参考にしつつ、実装してみた。

ぜんぜんダメじゃん、ジェレミー先生!

これを mode collapse というらしい。

なんど試してみても、損失関数の D値が0に収束、G値が大きくなっている。

Dは Discriminaterを表し、GのGeneratorが生成したものが本物か偽物かを判定する役割があるのだが、今起こっている状況は、Generator が生成した画像がほとんどの確率で偽物と判定されている。

GANを改善してみる

How to Train a GANというNIPS2016での発表があって、ここにGANの改善方法がまとまっている。

以下を試してみた。

  • ReLU の 代わりにLeakyReLUを使う。
  • BatchNormalizationを使う。
  • Adam の 学習率を小さくする。
  • ノイズは正規分布からサンプリングする。
  • ネットワークのニューロン数を変更する。

結果。

おっ、それらしく古代文字っぽいものが浮き出てきたぞ。

単純なGANだと、あまり成果がでないことはわかっているので、実験は早めに切り上げて次のステップ DCGANに進むことにする。