2023.06.30
Hyena: 次世代LLMへ向けたTransformerを越える新機械学習モデル
Is Attention All You Need? Part 3
TL;DR
- Hyenaという深層学習モデルが提唱されました。これは、前々回に紹介した状態空間モデルを利用するS4を更に発展させたもので、Transformerと同等の性能を持ちながら入力長の二乗でコストが増える問題を解決しました。
- TransformerのAttentionでは、入力をKey、Query、Valueに変換して処理します。Hyenaではその構造を一般化し、データで制御されたゲートと入力長と同等の長い畳み込みであるHyena Filterによる再帰的処理により、計算コストを抑えつつ、高性能を維持することに成功しました。
- Hyenaは長い入力長に対してTransformerより高速な上にScaling Laws(大規模になるほど高性能)の兆候を示しており、更なる大規模計算への応用が期待されます。
HyenaとGPTの計算量と性能の比較。計算量を増やすほど性能がどこまでも改善していくScaling Lawsの振る舞いがHyenaでも見られ、更にHyenaはGPTより20%ほど少ない計算コストで同等の性能を達成
入力長を増やした場合、Attentionを利用すると計算時間は飛躍的に増大するが、Hyenaの計算時間の増加は緩やかで、入力長が65,536に対しては100倍以上高速
Is Attention All You Need?(Third time)
Hungry Hungry Hippos(H3)
このようなゲームみたいです
State Space Models
$$
\begin{eqnarray}
\dot{x}(t) &=& \mathbf{A}x(t) + \mathbf{B}u(t) \\
y(t) &=& \mathbf{C}x(t) + \mathbf{D}u(t)
\end{eqnarray}
$$
これらの\(\mathbf{A}, \mathbf{B}, \mathbf{C}, \mathbf{D} \)はモデルの学習パラメータとなります。このSSMは入力(\(u_1, \dots, u_N\))に対して出力(\(y_1, \dots, y_N\))を計算するとき、フィルタ
$$
f = [\mathbf{CB}, \mathbf{CAB}, \mathbf{C} \mathbf{A}^2\mathbf{B}, \dots ,\mathbf{C}\mathbf{A}^{N-1}\mathbf{B}]
$$
を使った畳み込み(convolution)
$$
y_i = \mathbf{C}\mathbf{A}^i\mathbf{B}x_0 + (f \ast u)_i + \mathbf{D}u_i
$$
として処理できます。このように変形しておくことで、FFT(Fast Fourier Transformation)などを利用し非常に効率的に計算ができることをS4の紹介で解説しました。なお、以降では簡単のため
$$
y = \mathrm{SSM}_{\mathbf{A,B,C,D}}(u)
$$
と表記します。これらのフィルタ(行列\(\mathbf{A,B,C,D}\))は、モデルの性能を改善するために様々な条件が課されます。
- Induction Head
- Associative Recall
Headでは、入力で特別な記号\(\vdash\)の次の文字を記憶して出力するタスクです。具体的な例を見てみましょう。以下のアニメーションでは、入力の途中に\(\vdash\)が出現し、その後の単語\(f\)を記憶します。最後に出力を求める\(\vdash\)が現れるので、その次の文字\(f\)を出力します。
どちらも入力長と語彙の大きさによってタスクの難易度が増します。そして、これらのタスクを2層のモデルで実行すると、以下のようになります。特にInduction Headについては、SSMを利用したS4DとGated State Spacesの性能は非常に低くなっています。それらと比較してH3では、ほぼ100とS4DとGated State Spacesの弱点を克服し高い性能を発揮しております。なお、Attentionを利用した場合、両者のタスクのスコアは100です。
Synthetic language tasksの性能比較(論文を元に作成)
H3
$$
\mathbf{Q} \odot \mathrm{SSM}_\mathrm{diag}(\mathrm{SSM}_\mathrm{shift}(\mathbf{K}) \odot \mathbf{V})
$$
ここで\(\odot\)は、要素毎の積(アダマール積)です。具体的なイメージを掴むために、H3 Layerの動作を以下のアニメーションで見てみましょう。
$$
Q = u\mathbf{W}_Q, \quad K = u\mathbf{W}_K, \quad V = u\mathbf{W}_V
$$
で定義されます。そして、\(K\)をShift SSMに入力します。Shift SSMでは、\([a,b,c] \rightarrow [0, a, b]\)のような操作を含むSSMで、SSMのパラメータである行列\(\mathbf{A}\)の一部の要素のみが1(\(A_{i,j}=1\) for \(i -1 = j\))でそれ以外はゼロ(\(A_{i,j}=0\) for \(i – 1 \neq j\))となるような行列です。その出力結果と\(V\)の要素積を取り、Diag. SSMに入力します。Diag. SSMの行列\(\mathbf{A}\)は対角行列で、対角化HiPPO(S4D)を利用します。その出力結果と\(Q\)との要素積を取り、出力の射影行列\(\mathbf{W}_O \in \mathbf{R}^{d \times d}\)を掛け合わせて出力とします。このH3の計算コストとしては、\(\mathcal{O}(N\log N)\)であることがこの論文では示されています。これは通常のAttention機構である\(\mathcal{O}(N^2)\)よりも効率的です。なお、論文ではFlashConvというSSMの処理を高速化し学習を効果的にする手法も提案されています。詳細は論文を参照してください。
H3と他のモデルのOpenWebTextを用いたPerplexityの比較(図は論文を元に作成)
SuperGLUEによる各種ベンチマーク結果のまとめ(論文を元に作成)(注) 実際には355Mでは、OPT-350M
推論のスピードの比較(図は論文を元に作成)
Hyena Hierarchy
- Data control: データに制御された線形演算であり、1つのブロックで全ての線形関数を埋め込むことができる。Attention行列\(A\)を用いて、\(y = A(k,q)v \)と書き下せます。入力\(k, q\)で\(v\)についての線形関数を制御していることになります。
- Sublinear parameter scaling: Attention layer のパラメータの数は入力長と独立している。そのため、Attention layer間のニューラルネットワークなどに更にパラメータを追加できる。
- Unrestricted context: 入力のコンテキスト(順序など)に依存せずに任意の要素間の関係性を扱える。
これらの性質を指針としてHyenaは設計されております。
$$
y_t = (h\ast u)_t = \sum_{n=0}^{L-1} h_n u_{t-n}
$$
\(D \)は一般に入力の次元数ですが、簡単のために\(D=1 \)としたSISO(single input single output)の場合を考えますと、この計算はToeplitz kernel matrix \(S_h \in \mathbf{L \times L}\)との行列積
$$
(h \ast u) =
\begin{pmatrix}
h_0 & h_{-1} & \dots & h_{-L+1} \\
h_1 & h_0 & \dots & h_{-L+2} \\
\vdots & \vdots & \ddots & \vdots \\
h_{L-1} & h_{L-2} & \dots & h_0
\end{pmatrix}
\begin{pmatrix}
u_0 \\
u_1 \\
\vdots \\
u_{L-1}
\end{pmatrix}
$$
によって表現されます。Toeplitz行列とは、\(A_{i,j} = A_{i+1,j+1}\)と対角方向に同じ行列のことです。Convolutional Neural Network(CNN)などでは、このフィルタ\(h\)を直接に学習のパラメータとして最適化します。この場合、カーネルの大きさ\(M\)は入力長よりも一般的に小さくなります(\(M \ll L\))。
$$
h_t = \gamma_\theta(t)
$$
ここで、\(\theta\)は、関数\(\gamma_\theta\)のパラメータです。このような間接的な畳み込み(implicit convolution)を利用することで、入力長と同等の長い畳み込みフィルタを扱うことができます。その例が、まさにSSMを利用したフィルタです。
SSM
$$
\begin{eqnarray}
x_{t+1} &=& Ax_t + Bu_t \\
y_t &=& Cx_t + D u_t
\end{eqnarray}
$$
は、以下の畳み込み操作に帰着できます。
$$
y_t = \sum_{n=0}^t (CA^{t-n}B + D\delta_{t-n})u_n
$$
これをフィルタによる畳み込みと見ると、
$$
t \rightarrow h_t = \begin{cases}
0 \qquad &t < 0& \\ CA^t B + D \delta_t \quad &t \geq 0& \end{cases} $$ となります。
Hyena Operators
Hyena Recurrence
$$
\begin{eqnarray}
z_t^1 &=& v_t \\
z_t^{n+1} &=& x_t^n(h^n \ast z^n)_t \quad n \in 1, \dots N \\
y_t &=& z_t^{N+1} \\
\end{eqnarray}
$$
このHyena recurrenceの計算コストは、\(\mathcal{O}(NL\log_2 L)\)になります。この計算をまとめると以下のようになります。
$$
y = x^N \cdot(h^N \ast(x^{N-1} \cdot (h^{N-1} \ast \cdots (x^1 \cdot (h^1 \ast v)))))
$$
Attentionでは、QueryとKeyの重みでValueを処理していました。Hyenaでは、それが\(N\)-stepも続くような機構となっています。これだけでは抽象的すぎて何のことだかわからないので、以下で具体的に各ステップでの演算と特徴を説明します。
Hyena Matrices
$$
\begin{eqnarray}
A(q, k) &=& D_q S_\psi D_k S_\varphi \\
H3(q, k, v) &=& A(q, k)v
\end{eqnarray}
$$
ここでフィルタに相当する\(S_\psi, S_\varphi\)は、\(L \times L\)のToeplitz行列で学習パラメータです。具体的に書き下すと
$$
S_\psi =
\begin{pmatrix}
\psi_0 & 0 & \cdots & 0 \\
\psi_1 & \psi_0 & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
\psi_{L-1} & \psi_{L-2} & \cdots & \psi_0 \\
\end{pmatrix}, \quad
S_\varphi =
\begin{pmatrix}
\varphi_0 & 0 & \cdots & 0 \\
\varphi_1 & \varphi_0 & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
\varphi_{L-1} & \varphi_{L-2} & \cdots & \varphi_0 \\
\end{pmatrix}
$$
となります。因果性を持たせるために、\(S_\psi, S_\varphi\)の上三角の要素は0とします。これを一般化すると
$$
y = H(u)v = D_x^N S_h^N \cdots D_x^2 S_h^2 D_x^1 S_h^1v
$$
とHyena recurrenceをHyena行列の積として表現できます。ここで、\(D_x^n = \mathrm{diag}(x^n) \in \mathbf{R}^{L \times L}\)、\(S_h^n\)はフィルタ\(h^n\)の要素からなるToeplitz行列です。
Hyena Filters
$$
h_t = \mathrm{Window}(t) \cdot (\mathrm{FFN} \circ \mathrm{PositionalEncoding})(t)
$$
ここでPositionalEncodingとは、位置情報を付与する処理です。Hyenaでは、\(\rho_k(t) = e^{i2\pi k t/L}\) (\(k \in 0, \dots, K-1\))として、以下のように定義します。
$$
\mathrm{PositionalEncoding}(t)
= [t, \mathrm{Re}[\rho_0](t), \mathrm{Im}[\rho_0](t), \dots, \mathrm{Re}[\rho_{K-1}](t),
\mathrm{Im}[\rho_{K-1}](t)]
$$
これにより、元々の入力長に加えて\(D_e = 2K+1\)の次元が追加されます。また、Window関数として、
$$
\mathrm{Window}(t) = \exp\{-\alpha t\}
$$
を利用します。
- Hyena Projection: 入力\(u \in \mathbf{R}^{L \times D}\)に対して、\(x^1, \dots, x^N, v \in \mathbf{R}^{D \times L}\)を計算する。
$$x^1, \dots, x^N, v = \mathrm{Projection}(u)$$ - Hyena Filter: \(h^1, \dots h^N \in \mathbf{R}^{D \times L}\) を計算する。
$$h^1, \dots, h^N = \mathrm{HyenaFilter}(L, D_e)$$ - Forward pass: Hyena FilterをFFTを利用した高速なアルゴリズムで順次畳み込み計算していく。
$$
v_t \leftarrow x_t^n \cdot \mathrm{FFTConv}(h^n, v)_t
$$
を順次計算(\(n = 1, \dots, N\))し、最後の\(v (= y)\)として出力。
最終的な計算コストは入力 \(u \in \mathbf{R}^{L \times D}\)に対して、order-N Hyena演算子は
$$
\mathcal{O}(NDL(\log_2 L + D))
$$
であり、入力長の二乗未満でスケールしております。
Performance of Hyena
Hyenaと他のフィルタリングの手法の比較(論文を元に作成)
Hyenaと他のモデルの学習時のPerplexityと計算コストの比較(論文を元に作成)
HyenaとGPTのPerplexityのScaling Laws(図は論文を元に作成)
HyenaとGPTNeo、RWKVのSuperGLUEのパフォーマンスの比較(論文を元に作成)
HyenaとAttention、FlashAttentionの入力長を変えた場合の計算時間の比較(論文を元に作成)
Hyenaを利用した画像分類タスクの性能比較(図は論文を元に作成)
Summary: All We Need Is “Attention”
References
- “Attention Is All You Need”, https://arxiv.org/abs/1706.03762
- “Hungry Hungry Hippos: Towards Language Modeling with State Space Models”, https://arxiv.org/abs/2212.14052
- 【論文メモ】Hungry Hungry Hippos: Towards Language Modeling with State Space Models https://yuiga.dev/blog/posts/hungry_hungry_hippos_towards_language_modeling_with_state_space_models/
- [Journal club] Hyena Hierarchy: Towards Larger Convolutional Language Models
- Hyena Hierarchy: Towards Larger Convolutional Language Models
- “Hyena Hierarchy: Towards Larger Convolutional Language Models”, https://arXiv/abs/2302.10866
// Set the date we're counting down to var countDownDate = new Date("Jan 1, 2027 12:00:00").getTime();
// Update the count down every 1 second var x = setInterval(function () {
// Get today's date and time var now = new Date().getTime();
// Find the distance between now and the count down date var distance = countDownDate - now;
// Time calculations for days, hours, minutes and seconds var days = Math.floor(distance / (1000 * 60 * 60 * 24)); var hours = Math.floor((distance % (1000 * 60 * 60 * 24)) / (1000 * 60 * 60)); var minutes = Math.floor((distance % (1000 * 60 * 60)) / (1000 * 60)); var seconds = Math.floor((distance % (1000 * 60)) / 1000);
// Display the result in the element with id="demo" document.getElementById("demo").innerHTML = days + "d " + hours + "h " + minutes + "m " + seconds + "s ";
// If the count down is finished, write some text if (distance < 0) { clearInterval(x); document.getElementById("demo").innerHTML = "EXPIRED"; } }, 1000);
グループ研究開発本部の最新情報をTwitterで配信中です。ぜひフォローください。
Follow @GMO_RD