2020.04.15

強化学習(World Model)の紹介

Pocket

背景

こんにちは、次世代システム研究室のK.Nです。今日は、強化学習の比較的新しいモデルである、WorldModelについて紹介します。(本家のHPがかなり充実しているので、実際のところ、そちらで十分な感じもありますが..)。以降、計算結果までのの段落で使わる図はすべて、そこから引用しています。

こちらで紹介したように、強化学習の手法は大まかに「モデルあり」、「モデル無し」で分けることができます。ここでいう、モデルとは環境に対してであり、ある状態と行動を所与としたときの次時点の状態の条件付き確率を意味します。有名な強化学習の手法であるDQNPPO2等は「モデルフリー」になります。

モデルフリーの強化学習は、環境構造に対する理解を試行錯誤した経験によって習得します。そのため、多くのサンプルが必要となり、学習が非効率です。また、観測データをそのまま環境に対する情報として扱っているため、環境の変化やタスクが変わったときにうまく対応できない、つまり汎化性能が低いという問題があります。反面、本来知り得ない環境をモデル化することにより発生するバイアスが生じる問題はなくなります。

WorldModelは、これらの問題を解決するために、環境に対する抽象的なモデルを構築し、そのモデル上で強化学習を行います。環境のモデルに対する学習は強化学習のタスクとは独立に行うことができるので、転移学習のような応用も可能だと考えられます。

また、WorldModelの構成は以下の図のようなVMCの3つの部品から構成されます。このモデル構成こそが今回の肝になります。

  • V(Vision) 抽象表現の学習
  • M(Memory)  将来時点の状態を予測
  • C(Controller) MVの結果をもとに、行動を選択。

大変シンプルで役割もはっきりして分かりやすいです。具体的には、VVAEを、MRNNを、そしてCには単純な線形関数を利用しています。このことからも分かる通り、世界モデルの構築(抽象表現の獲得とその将来予測)こそが重要であり、世界モデルが十分正確ならば、強化学習のタスク自体はより単純なモデルで処理できるというのが今回提案された手法の思想です。この手法の考え方には、人間の認知方法があります。

それは、我々は目の前にある現実世界を、完璧なイメージを描いているわけではなく、抽象化した内部モデルを通して捉えています。そして、常日頃、脳の中の内部モデルを使って、将来の刺激を予測しながら、行動していると言われています。その意味では、WorldModelでは、時間的な表現をRNNで、空間的な表現をVAEで大規模な内部モデルを構築し、タスクの実行自体は小規模なモデルで表すことにより、個々の役割をはっきりさせています。また、VAEではなくGANPCAで使うであったり、RNNではなく1D-CNNを使うといったことも応用できる可能性は高いと思います。


モデルの構成

具体的に、モデルの構成について見ていきます。ここでは、強化学習のタスクとしてTVゲームを考えています。


VAE (V) Model

強化学習の行動主体(エージェント)は毎時点状態として画像を受け取ります。VAEでは、その画像を圧縮した抽象ベクトルzを学習します。VAEの学習では抽象ベクトルから復元した画像が元の画像に一致するようエンコーダ、デコーダのパラメータを変化させます。

MDN-RNN (M) Model

Vで圧縮した抽象ベクトルと行動から次の時点の圧縮ベクトルを予測するために、RNN(LSTM)を使います。

さらに、混合ガウス分布(MDN)を導入することで、確率的な振る舞いを表現しています。RNNの出力値をMDNのパラメータとしています。MDNで確率分布p(z)の多峰性が表現できるので、考えうる将来の複数のシナリオを自然な形で表されます。

 

Controller (C) Model

Cでは、強化学習のタスクの実現を担っています。つまり、期待される累積報酬が最大となるような行動を取るようになるよう学習します。このモデルは抽象ベクトルzとRNNの隠れベクトルhをインプットとして、単純な1次式で表されます。

World Model

MVCの相互関係を次図に示す。


下図がOpenAIGym環境上のWorldModelの模擬コードである。


VとMは通常のDeepLearningモデルと同様、back propagationで微分値を計算し、最急降下法ベースの最適化で学習される。Cは単純なモデルゆえ、より野心的な最適化手法、今回は共分散行列適応進化選択(CMA-ES)を採用している。この手法は進化戦略の一つであり、より大域的な探索が期待される。

計算結果

本家にある通り、カーレースを題材にWorldModelによる強化学習を行ってみたいと思います。実装も公開されているため、それを参考にしていただきました。

手順

今回のカーレースの学習の簡単な手順は次のとおり。
  1. ランダムな方策で、2000エピソードのサンプルを集める。(論文では10,000だが、今回は成約上少なく)
  2. 集めた画像を32次元の抽象ベクトルに圧縮するようVAEを学習
  3. MDN-RNN (M) を学習
  4. 1レースの累積報酬が最大となるようコントローラ(C)を学習。
再度強調すると、1-3では強化学習のタスク自体は行っていない。つまり、世界モデルの学習自体には、強化学習のタスクは無関係である。報酬をデータとして用いるのは4の段階になってからである。

学習結果

まず、ランダムな方策でカーレースをすると、次のような具合になります。もちろん、ランダムなのでうまく走れません。下の動画は、ランダム方針の元のレース動画及びVAEにより圧縮後、再度復元した動画になります。おおよそ再現されている様子がわかります。



 

続いて、下の動画は学習されたRNNに対して、初期時点の抽象ベクトルのみを与え、その後予測される抽象ベクトルをVAEによって復元し、描写したものです。カーレースは開始後すぐに左カーブに入ることが多いですが、その事実を反映した結果になっております。



つづいて、肝心のコントローラ(C)を学習させ、無事強化学習のタスクを攻略させることができたかどうかを確認したかったのですが、間に合わなかったため残念ながら今回はここまでとします。


最後に

次世システム研究室では、ビッグデータ解析プラットホームの設計・開発を行うアーキテクトとデータサイエンティストを募集しています。興味を持って頂ける方がいらっしゃいましたら、ぜひ 募集職種一覧からご応募をお願いします。
一緒に勉強しながら楽しく働きたい方のご応募をお待ちしております。