Transformerのマルチヘッドアテンションに関する論点

はじめに

本論文では、マルチヘッドアテンション(Multi-Head Attention, MHA)の仕組み、計算フロー、およびその利点と限界について解説します。マルチヘッドアテンションは、Transformerモデルの中核を成す技術であり、自然言語処理において非常に有効な手法です。この論文では、マルチヘッドアテンションの各要素の詳細、シーケンス全体での計算方法、そしてヘッド数を増加させる際の考慮事項について述べます。

マルチヘッドアテンションの概要

マルチヘッドアテンションとは、入力トークンの埋め込みベクトルを複数の「ヘッド」に分割し、各ヘッドが独立してアテンションを計算する手法です。各ヘッドがシーケンス全体に対してアテンション計算を行うことが特徴であり、各ヘッドが異なる特徴空間に対して独自に情報を処理します。これにより、各トークン間の関係を多角的に評価することが可能です。

各ヘッドにおける計算の仕組み

クエリ、キー、バリューの生成

マルチヘッドアテンションの各ヘッドでは、入力トークンの埋め込みベクトル(例:512次元)に対して専用の重み行列を用いてクエリ(Q)、キー(K)、バリュー(V)を生成します。埋め込み次元が512、ヘッド数が8の場合、各ヘッドは64次元のクエリ、キー、バリューを持ちます。これにより、元の埋め込みベクトルを小さな次元に分割し、各ヘッドが異なる情報を抽出することが可能になります。

アテンションスコアの計算

各ヘッドは、自身のクエリとキーの内積を計算し、スケールした後にソフトマックスを適用してアテンションスコアを得ます。アテンションスコアはシーケンス全体のトークン間の関係を示しており、各トークンが他のトークンに対してどれだけ重要であるかを計算します。この計算は各ヘッドが独立して行うため、同じシーケンスに対しても異なる視点で情報を評価できます。

アテンション出力の生成

アテンションスコアを用いて、バリューに重みをかけて加重平均することで最終的なアテンション出力を得ます。この出力は、たとえば10トークン × 64次元の行列として表されます。各ヘッドのアテンション出力は異なる特徴を捉えるため、それぞれの出力が重要です。

ヘッドの連結と最終出力の生成

全てのヘッドのアテンション出力を連結(コンカット)し、1つの大きな行列にします。この連結された行列を線形変換にかけることで、元の埋め込み次元(例:512次元)に再構成します。このプロセスにより、各ヘッドで抽出された異なる特徴が統合され、モデルがより豊かな表現を生成することができます。この連結操作は、水平連結として実行され、各ヘッドの出力を並べて大きな次元のベクトルに変換します。

シーケンス全体に対するアテンション計算

マルチヘッドアテンションでは、次元を分割した各ヘッドがシーケンス全体に対して総当たりでアテンション計算を行います。つまり、各クエリが全てのキーと比較されるため、シーケンス全体の文脈が考慮されます。これにより、各トークンがシーケンス内の他のすべてのトークンとの関係性を持つことが可能となります。

  • シーケンス全体の参照:たとえば、10個のトークンがある場合、各ヘッドは全てのクエリと全てのキーのペアを計算します。このようにして、各ヘッドはシーケンス全体の文脈を反映した情報を生成します。
  • 異なる視点の提供:各ヘッドが異なる重み行列を使用するため、同じシーケンスに対しても異なる特徴を抽出し、多様な情報を提供します。

ヘッド数が多すぎる場合の問題点

ヘッド数を増やしすぎるといくつかの問題が発生します。具体的には以下の通りです:

  1. 各ヘッドの次元が小さくなりすぎる
  • ヘッド数を増やすと各ヘッドの次元が小さくなり、十分に複雑な特徴を捉えられなくなる可能性があります。たとえば、埋め込み次元が512でヘッド数が32の場合、各ヘッドの次元は16となり、表現力が制限されます。
  1. 計算コストとメモリ使用量の増加
  • ヘッド数が増えると計算量やメモリ使用量も増加します。これにより、トレーニングや推論の効率が低下する可能性があります。
  1. ヘッド間の冗長性
  • 多くのヘッドが同様のパターンを学習してしまうことがあり、冗長性が発生します。その結果、モデルの効率が下がり、必ずしも性能が向上しない場合があります。

結論

マルチヘッドアテンションは、各ヘッドがシーケンス全体に対してアテンションを計算し、多様な特徴を抽出することで、モデルの表現力を大幅に向上させる手法です。次元を分割して異なる特徴空間で計算することにより、同じ入力に対して異なる視点を提供し、シーケンス全体の文脈を捉えることができます。

しかし、ヘッド数を過剰に増やすと各ヘッドの次元が小さくなりすぎて表現力が低下することや、計算コストやメモリ使用量の増加、さらに冗長性の発生といった問題が生じる可能性があります。そのため、埋め込み次元とのバランスを考慮し、タスクに応じて適切なヘッド数を選定することが重要です。実験的に調整を行うことで、最も効果的な設定を見つけることが推奨されます。