2020.04.01
PyTorchによるニューラルネットワークの学習パス( 2層NNからGANまで)
こんにちは。次世代システム研究室のY.R.です。外国人です。
現在ディープラーニング(深層学習)においてPyTorchは非常に注目されているフレームワークである。日本で株式会社Preferred NetworksがPyTorchに移行することは話題になった[1]。
PyTorch自体と周りの発展[2][3][4]及びPyTorchとTensorflowの 比較[5]に興味が持っている方々が引用に参考して良いと思う。
本ブログPyTorchで実装した事例を紹介する。こちらで自分の勉強を皆さんと共有したいと思う。
初めに
ディープラーニングフレームワークとしてPyTorchは大体以下の図に示している通りにいくつか重要なコンポーネントで構造されている。
これからそれぞれのコンポーネントの使い方と関連性を実践で説明する。
実践1: 簡単なニューラルネットワーク
http://cs231n.github.io/neural-networks-1/
先ず、この図のような一つの隠れ層(例えばRelu)で構造されているニューラルネットワークを実装する。こちらでPyTorchによる実装とNumpyによる実装を比べる。
左側はnumpyによる実装であるが、右側はtorchによる実装である。二つの実装を比べると、PyTorchを通じるニューラルネットワークの構築(1)は複雑になるとみられるが、ロジックは明確になれる。PyTorchでニューラルネットワークのパラメータ調整(2と3)は明らかに簡単なことになる。複雑なニューラルネットワークの実装でパラメータの調整はよく一番難しいことになっている。更に、ニューラルネットワークの効果は学習率(learning_rate)の設定に緊密な関連性が持ている。PyTorchに多くの学習率の設定方法が置いてある。
実践2: Long Short Term Memory(LSTM)
時系列予測においてLSTMはよく利用されるニューラルネットワークである。LSTMに原理は参考[6]分かり易く説明していただいてある。
こちらでPyTorchによる実装を簡単に紹介する。
LSTMもPyTorch のnn.LSTMに実装されている。使用者として必要なことはいくつかのパラメーターを指定することだけです。LSTMといる類インスタンス化した後で、相応のLSTMの構造は簡単に展示される。
実はこのLSTMの実装と「簡単なニューラルネットワーク」を並べると、共通点が多いと感じられる。
これから、モデルをトレーニングで重要なステップを紹介する。
簡単なニューラルネットワークと同様にtorch.optimといる最適化ツールを利用した。torch.optimはPyTorchの重要なコンポーネント(gradient based optimization)である。PyTorchの[automatic differentiation]のおかげで、私たちは相応の損失関数と最適化インスタンスを初始化した後で、簡単にニューラルネットワークの調整できる。
実践3: Generative adversarial networks(GAN)
敵対的生成ネットワーク(GAN)はイアン・J・グッドフェロー(Ian J. Goodfellow)に2014年で提出されている。GANはヤン・ルカン(Yann LeCun機械学習業界の名人)により最近の20年間で機械学習の最もクールなアイデアであると呼ばれた[7]。更に、GANも様々なエリアで応用されている。特に、コンピュータビジョンにおいてGANは本物みたいな画像を生成できる。
簡単に説明したら、GANは以下のようにゼロサムゲームフレームワークで互いに競合する2つのニューラルネットワーク(こちらで犯罪者と探偵)のシステムによって実装される。トレーニングのイテレーションによって両方も強いモデルになれる。
GANに関して深く理解したいなら、参考[8]でGANの原理をはっきり紹介されたと思う。
これから簡単な例を通してGANを実装してみる。
こちらで複雑なデータ生成の場合ではなくで、以上の図で標準正規分布からだんだん指定された正規分布を生成してみる。
GANの紹介のように、実装に二つのニューラルネットワーク(GとD)がある。二つのニューラルネットワークもそれぞれに最適した。気になることは、今度nn.Sequentialを駆使することにしてみたである。ニューラルネットワークの実装は易くなった。
この図は指定された正規分布(平均値:5, 分散:2)である。
この図は生成された正規分布である。本物の正規分布と比べていると、二つの部分はよく似ている。平均値と分散は以下の通りである。
Real Dist ([4.9714283221028746, 2.0351741871237734]),
Fake Dist ([5.01533369487524, 1.9048349195749072])
本物の分散拡大しながらも、生成された分散と本物の差別は明らかになっていると見つけていた。こちらで本物は単なる正規分布であるが、こんな状況も事前知られておく。こんなことは本番のエリアであり得ないと思う。
感想
今度のブログで自分はPyTorchを勉強していることを共有している。単なる二層のニューラルネットワークから注目されている敵対的生成ネットワークまで三つのニューラルネットワークを実装したこと紹介した。実装のソースも公開され[8]。
PyTorchとNumpyの相互運用性非常に便利であると感じる。動的計算グラフという特徴(Define-by-Run)のおかげで、デバッグも難しくないと思う。
現在ディープラーニングに関する研究また速く進行している、新たなモデルが次々に提出されている。研究業界に大人気のフレームワークとして、PyTorchを常に利用されている。相関の研究に興味が持っている方々はPyTorchで検証し易くなると思う。PyTorchを巡り、様々なツール[9] [10] [11]も提出されているし、素晴らしいチュートリアルも公開されている[12][13][14]。PyTorchが広く使われていくと信じている。PyTorchに興味がある方々とコミュニケーションを期待して頂く。
最後に
次世システム研究室では、ビッグデータ解析プラットホームの設計・開発を行うアーキテクトとデータサイエンティストを募集しています。興味を持って頂ける方がいらっしゃいましたら、ぜひ 募集職種一覧からご応募をお願いします。
一緒に勉強しながら楽しく働きたい方のご応募をお待ちしております。
グループ研究開発本部の最新情報をTwitterで配信中です。ぜひフォローください。
Follow @GMO_RD