2023.05.15

入力トークン数オーバー200万
Recurrent Memory Transformer
Is Attention All You Need? Part 2

TL;DR

  1. Recurrent Memory Transformer(RMT)というモデルが、200万トークンもの長い入力にも対応しました。GPT-4の上限(32,000トークン)と比較すると飛躍的な長さといえます。
  2. RMTでは、入力を細かいセグメントに分割しメモリー・トークンを付加して順次処理するモデルで長時間の記憶を保持するという特殊なタスクに対して有効でした。
  3. ただし、個別のタスクにファイン・チューニングした結果であり、GPTのような汎用的な文章生成タスクではない点に注意が必要です。

Is Attention All You Need?(Again)

こんにちは、グループ研究開発本部・AI研究室のT.I.です。前回のBlogでは、ChatGPTなどLLM(Large Language Model)の性能の鍵であるAttentionを超える可能性のあるモデル(S4)を紹介しました。
さて、現状はどうなっているでしょうか?再び、“Is Attention All You Need?”
をアクセスしてみますと、このような感じになっていると思います。

Remaining Time:



やはり、皆さんご存知のようにAttention(Transformer)は、まだ必要とされている様子です。前回のBlogで紹介したようにTransformerの弱点は、長い入力長には対応できないことです。Attentionとは、Query(Q)に対してKey(K)との重なりを計算し(softmax関数を作用させ)、その重み付けしたValue(V)を出力する関数
$$
\mathrm{Attention}(Q, K, V) \sim \mathrm{softmax} \left(
QK^T \right)V
$$
で定義されます。そのため計算コストが入力(\(L\))に対して長さの二乗(\(\mathcal{O}(L^2)\))で増加してしまいます。GPT-4では、3,200トークンの入力限界があり(OpenAI GPT-4)、CoLT5では、6,400トークンまでの入力に対応しています。
以下のアニメーションは入力された単語間のAttentionの様子を模式的に表したものです。実際のモデルでは、複数のAttention層が重なっており(Multi-head attention)、その出力結果を更に何度も繰り返しMulti-head Attentionに通すことで性能を高めています。

以前、このブログではこの長い入力長に対応する一つの手法として、Recurrent Neural Networkを使った「S4」モデルを紹介しました。しかしここに来て、新たなモデル「Recurrent Memory Transformer」(RMT)が登場しました。名前からもわかる通り、Transformerを利用しています。そして何と、このモデルは最大200万トークンもの長い入力に対応できることが実証されました。

GPT-4, CoLT5, RMTの入力トークン数の比較
一体に何が起きたのでしょうか?今回のブログでは、この「Recurrent Memory Transformer」(RMT)モデルについて詳しく紹介します。ちなみにですが、Harry Potterシリーズ(7巻)は英語で108万ワード(約145万トークン)らしいので、200万トークンというと、余裕で全部読み込めます。(A Song of Ice and Fire(ゲーム・オブ・スローンズの原作)は、未完ですがすでに200万トークン越え)

小説のトークン数 (100 token = 75 wordsで換算)。
参考資料: Word counts of the most popular books in the world

Recurrent Memory Transformer

まずは、Recurrent Memory Transformerとは、どのようなモデルなのでしょうか?
簡単に述べると
  1. 入力を短いセグメントに分割
  2. 特別なメモリー・トークン(論文に合わせ図などではmemと表記)を付加してTransformerに入力
  3. 出力でもメモリー・トークン込みで出力
  4. 再度、このメモリー・トークンを次の入力に追加して入力
することを順次繰り返します。Recurrent Memory Transformer の動作をアニメーションで著したものが以下の図です。学習時にはこのメモリー・トークン込みで逆方向へのフローとなります。



具体的に数式で解説すると以下の通りです。最初に提案されたRMT(“Recurrent Memory Transformer”, Advances in Neural Information Processing Systems, volume 35, pages 11079-11091.)では、Transformer Decoderも利用しているためメモリー・トークンをReadとWriteと2種類に分けて利用していました。今回のものはBertのようにTransformer Encoderのみなので、メモリー・トークンは1種類だけ利用します。
具体的には\(H_\tau^0\)を元の入力のセグメントとして、メモリー・トークン\(H_\tau^{mem}\)を以下のように添付します。
$$
\tilde{H}_\tau^0 = [H_\tau^{mem} \circ H_\tau^0]
$$
これをTransformer(\(N\)層)で処理した結果
$$
\mathrm{Transformer}(\tilde{H}_\tau^0)
= [\bar{H}_\tau^{mem} \circ H_\tau^N]
$$
この処理によって、セグメント\(\tau\)での情報を\(\bar{H}_\tau^{mem}\)に書き込みます。次の\(\tau + 1\)のセグメントも同様にメモリー・トークンを付加して、順次処理します。
$$
\tilde{H}_{\tau+1}^0 = [H_{\tau+1}^{mem} \circ H_{\tau+1}^0]
$$
ここで、\(\bar{H}_\tau^{mem} = H_{\tau+1}^{mem}\)として1つ前のステップのメモリー・トークンの出力結果を再起的に利用します。

通常のTransformerでは、入力の二乗に比例して計算コストが増加します。RMTとOPT(Open Pre-trained Transformer)の入力長さと計算コストを可視化したものが次の図です。


Transformer(OPT)とRMTの入力長に対する計算量の比較(論文の図を元に作成)

OPTのパラメータ数(125M、...、175B)が増加するにつれて、入力長に対する計算コストの増加は緩やかではありますが、最終的には二乗に比例して急速に増大します。一方、RMTでは、短いセグメントに分割し、一定の長さのTransformerで繰り返し処理することで、計算コストの増加は線形に抑えられます(自明ではありますが)。グラフを見ると差があまりないように感じるかもしれませんが、縦軸のスケールが対数であることに注意してください。

Memory-intensive Synthetic Tasks

この論文では、RMTの性能評価として、簡単な記憶タスクを実行しました。そもそもの問題、そんなに長い入力のテストデータがあるはずもないので、2種類のデータセットを合成して作成しております。1つ目のデータセット(bAbI
dataset
)は、記憶させるべき事実として利用され、もう一方のデータセット(QuALITY long QA dataset)は、回答に無関係なノイズとして組み合わせました。具体的なタスクは以下の3つです。
    1. Fact Memorization
    2. Fact Detection and Memorization
    3. Reasoning with Memorized Facts
Fact Memorizationタスクでは、最初に事実を読み込み、その後はノイズが続き、最後に質問が投げかけられて回答させる形式です。長期間の記憶を保持できるかが重要となります。具体的には以下のアニメーションのようになります。
      • Fact: Daniel went back to the hallway.
      • Question: Where is Daniel?
      • Answer: hallway



Fact memorization taskの概要(論文の例を参考に作成)

Fact Detection and Memorizationタスクは、Fact Memorizationに似ていますが、事実が入力の途中でランダムに現れます。ノイズの中から適切に事実を見つけ出し、その記憶を保持する能力が求められます。具体的には以下のアニメーションのようになります。



Fact Detection and memorization taskの概要(論文の例を参考に作成)

Reasoning with Memorized Factsタスクでは、2つの事実を元に質問に回答します。1つ目の事実は最初に与えられますが、2つ目は途中でランダムに投入されます。長時間の記憶の保持と2つ目の事実を判断する能力が問われます。
      • Fact1: The hallway is east of the bathroom.
      • Fact2: The bedroom is west of the bathroom.
      • Question: What is the bathroom east of?
      • Answer: bedroom



Reasoning with Memorized Facts task の概要(論文の例を参考に作成)

これらのタスクを通じて、RMTの記憶と推論能力が評価されます。ノイズの中から適切な情報を抽出し、記憶を維持しながら質問に回答することができれば、RMTの性能が実証されることになります。

Performance

今回の実証実験では、bert-base-cased model from HuggingFace Transformer(参考: bert-base-cased)を使用しました。詳細については、こちらのGitHubリポジトリ(https://github.com/booydar/t5-experiments/tree/scaling-report)で公開されています。学習には4-8台の Nvidia 1080ti GPUを利用し、長い入力長での評価時には40GB Nvidia A100を1台で実施しました。RMTでは、全512トークンの入力のうち、10トークンをメモリとして使用し、残りの499トークンを入力セグメントとして扱います。また、残りの3トークンはセパレータとして活用します。このモデルを利用したAttentionについては、こちらのexBERTで実際に触ってみることができます。単純に1つの単語が1つのトークンに対応している訳ではないので注意してください。

次に、異なる長さのセグメント(1から7)で学習を行い、それぞれ異なる長さのセグメント(1から15)で評価を行った結果について説明します。学習に使用した入力長が短すぎると精度が低下しますが、ある程度の長さ(4から5セグメント)で学習を終えると、それ以上の長さの入力でも精度が向上することが確認されました。



RMTで学習データのセグメント長ごとの評価精度のまとめ。
(論文中の図を元に作成)
こうして学習したモデルの入力を更に4,096から2,043,904トークン(つまり64セグメントから4,096セグメント)まで拡大した最終的なタスクごとの性能評価(Accuracy)のまとめが以下の図になります。入力が長くなるほどに性能は徐々に低下しますが、それでも200万トークンもの長い入力に対しても一定以上の水準で回答しています。個別のタスク別に難易度を比較してみると、最初の事実を最後まで保持して回答する必要があるFact Memorizationよりも、途中の事実を検出して回答するFact Detection and Memorizationの方が精度が良いです。また、最も複雑なタスクであるReasoning with Memorized Factsの性能が他と比較して低くなっています。
RMTの長い入力長で評価した場合の性能のまとめ(論文中の図を元に作成)

最後にRMTの記憶のメカニズムが適切に働いていることを示唆するAttention機構の結果の例を紹介します。これらの図は、Reasoning with Memorized Factsにおける入力に対して、それぞれのステップに応じて、どのトークンが反応しているのか可視化したものです([CLS], [SEP]は入力のセパレータです)。

最初の事実を検出して、メモリー・トークンが強く反応しており、結果を記憶します。



Fact Detectionにおけるアテンション機構の可視化の例(論文の図を元に作成)

ただのノイズの場合、メモリー・トークンは反応せず、記憶は保持されています。



2つ目の事実を検出したため、再びメモリー・トークンが反応し記録されます。



Fact Detectionにおけるアテンション機構の可視化の例(論文の図を元に作成)

最後に質問が投げかけられ、記憶を読み出し回答します。



回答時におけるアテンション機構の可視化の例(論文の図を元に作成)

このように、特定のAttention層の可視化からもRMTが、長期間の記憶を保持してタスクを実行していることが示唆されます。

Summary

今回は最近発表されたRecurrent Memory Transformer(RMT)を紹介しました。従来のTransformerは、長い入力に対して計算コストが飛躍的に増大する問題がありGPT-4では、最大32,000トークンまでですが、一方でRMTでは、200万トークンもの推論タスクをパスできました。

RMTの鍵となる仕組みは、入力を短いセグメントに分割し、それぞれにメモリー・トークンを付加して次の入力に再帰的に引き継ぐことです。これにより、長期間にわたる記憶の保持が可能となり、3つのタスク(Fact Memorization, Fact Detection and Memorization, Reasoning with Memorized Facts)を200万トークンもの入力に対して処理することに成功しました。

このように、少数のメモリー・トークンを再帰的に利用して長期の記憶を保持する手法は非常に興味深いもので、これほどまでに長い入力に対応できるという事実は注目すべき点です。しかし、ここで注意しておきたいのは、これが今回紹介した3つの比較的単純なタスクでの性能であり、例えばChatGPTのように文章を生成するといったより複雑な応用が可能かどうかは、まだ確認されていません。その点については、今後の調査や研究が待たれます。

グループ研究開発本部 AI研究開発室では、データサイエンティスト/機械学習エンジニアを募集しています。ビッグデータの解析業務などAI研究開発室にご興味を持って頂ける方がいらっしゃいましたら、ぜひ募集職種一覧からご応募をお願いします。皆さんのご応募をお待ちしています。

References

    1. “Attention Is All You Need”,
      https://arxiv.org/abs/1706.03762.
    2. “Scaling Transformer to 1M tokens and beyond with RMT”,
      https://arxiv.org/abs/2304.11062.
    3. “Recurrent Memory Transformer&rdquo, Advances in Neural
      Information Processing Systems, volume 35, pages 11079-11091.)
    4. CoLT5: Faster Long-Range Transformers with Conditional Computation,
      https://arxiv.org/abs/2303.09752.
    5. OPT: Open Pre-trained Transformer Language Models,
      https://arxiv.org/abs/2205.01068.
    6. QuALITY: Question Answering with Long Input.
      (https://github.com/nyu-mll/quality)
    7. bAbI
      https://github.com/facebookarchive/bAbI-tasks

  • Twitter
  • Facebook
  • はてなブックマークに追加

グループ研究開発本部の最新情報をTwitterで配信中です。ぜひフォローください。

 
  • AI研究開発室
  • 大阪研究開発グループ

関連記事