Transformerモデルのデコーダによる最終的なトークンの選択

最終的なトークンの選択は、Transformerモデルのデコーダが出力する確率分布に基づいて行われます。このトークン選択のプロセスは、以下のステップに分けて詳しく説明します。これには、ソフトマックス関数による確率の計算、確率に基づく選択アルゴリズム、そして実際のトークン決定が含まれます。

1. 確率分布の生成とソフトマックス

トークンの選択は、Transformerデコーダが生成する確率分布から行われます。デコーダはエンコーダの出力および過去のターゲットシーケンスを用いて、語彙全体に対する各トークンの「ロジット(生のスコア)」を計算します。このロジットは、それぞれのトークンが次に来る可能性を数値的に表現しています。

ステップ1: ロジットの計算

  • 入力:デコーダにはエンコーダの出力(文脈ベクトル)と、すでに生成されたトークンが入力されます。
  • 計算:デコーダはこれらの情報を基に、次に来る可能性のあるトークンに対してスコアを計算します。
  • ロジットの出力:各トークンに対するロジット(スコア)が計算されます。これは語彙のサイズを持つベクトルです(例えば、10,000単語からなる語彙なら10,000次元のベクトルになります)。

ステップ2: ソフトマックス関数の適用

ロジットから確率を得るために、ソフトマックス関数を適用します。

  • ソフトマックス関数は各トークンのロジットを0から1の範囲の確率に変換し、その合計が1になるように調整します。 数式で表すと次のようになります: [
    P(y_t = i | y_{<t}, X) = \frac{\exp(\text{logit}i)}{\sum{j} \exp(\text{logit}_j)}
    ] ここで、(\text{logit}i) は語彙のトークン (i) に対するスコア、(y{<t}) はこれまでのターゲットシーケンス、(X) はエンコーダの出力です。この結果が、各トークンの選択に使用される確率分布です。

2. デコーディング手法に応じたトークンの選択

確率分布が得られた後、次に来るトークンを選ぶためのプロセスが行われます。このプロセスは、使用するデコーディング手法に依存します。以下に、トークン選択の主要なアルゴリズムの具体的な動作を説明します。

a. グリーディーデコーディング(Greedy Decoding)

グリーディーデコーディングでは、単に最も高い確率を持つトークンを選択します。

  • 手順
  • 確率分布の中で最も高い確率を持つトークンを検索します((\arg\max)操作)。
  • 例えば、語彙サイズが10,000の場合、その中で最も高い確率を持つインデックスを選びます。
  • このインデックスに対応するトークンが次のトークンとして選ばれます
  • 実装例(Pythonでの擬似コード)
  # 確率分布を表すテンソル
  probs = torch.softmax(logits, dim=-1)

  # 最も高い確率を持つトークンのインデックスを取得
  next_token = torch.argmax(probs, dim=-1)

b. ビームサーチ(Beam Search)

ビームサーチでは、次に来るトークンを複数の候補シーケンスの中から選択します。このプロセスは複数のビーム(候補シーケンス)を並行して追跡し、それぞれについての確率を計算します。

  • 手順
  1. ビーム幅を設定(例:(k = 3))。
  2. 確率が高い上位 (k) 個のトークンをそれぞれ選択。
  3. 各ビームに対して次のステップで再び確率を計算し、それを続けて最も確率の高いシーケンスを選ぶ
  • 実装例(Pythonでの擬似コード)
  # 確率分布を取得
  probs = torch.softmax(logits, dim=-1)

  # 上位k個のトークンとその確率を取得
  top_k_probs, top_k_indices = torch.topk(probs, k=beam_width)

  # ビームごとに確率を更新し、次のトークンを選択する

c. トップKサンプリング(Top-K Sampling)

トップKサンプリングでは、上位の (K) 個のトークンの中からランダムにトークンを選びます。

  • 手順
  1. 確率分布をソートし、上位 (K) 個のトークンを取得。
  2. その中から確率に応じてランダムにトークンをサンプリング
  • 実装例(Pythonでの擬似コード)
  # 確率分布を取得
  probs = torch.softmax(logits, dim=-1)

  # 上位K個を取得
  top_k_probs, top_k_indices = torch.topk(probs, k=K)

  # ランダムに次のトークンをサンプリング
  next_token = top_k_indices[torch.multinomial(top_k_probs, 1)]

d. トップPサンプリング(Top-P Sampling)

トップPサンプリング(Nucleus Sampling)では、累積確率が閾値 ( P ) 以上になるまでトークンを選び、その中からランダムにトークンを選択します。

  • 手順
  1. 確率分布を降順にソートし、累積確率を計算。
  2. 累積確率が閾値 ( P ) を超えるトークンセットを選択。
  3. その中からランダムにトークンをサンプリング
  • 実装例(Pythonでの擬似コード)
  # 確率分布を取得
  probs = torch.softmax(logits, dim=-1)

  # 確率を降順にソート
  sorted_probs, sorted_indices = torch.sort(probs, descending=True)

  # 累積確率を計算
  cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

  # 閾値Pを超えるトークンセットを選択
  top_p_mask = cumulative_probs <= P
  valid_indices = sorted_indices[top_p_mask]

  # ランダムにトークンをサンプリング
  next_token = valid_indices[torch.multinomial(probs[valid_indices], 1)]

e. 温度スケーリング(Temperature Scaling)

温度スケーリングでは、温度パラメータ ( T ) を使って確率分布の鋭さを調整し、その後にトークンをサンプリングします。

  • 手順
  1. ロジットを温度 ( T ) で割ってからソフトマックスを適用。
  2. 確率分布からトークンをランダムに選択。
  • 実装例(Pythonでの擬似コード)
  # 温度Tを使用してロジットを調整
  adjusted_logits = logits / temperature

  # ソフトマックスを適用して新しい確率分布を取得
  adjusted_probs = torch.softmax(adjusted_logits, dim=-1)

  # ランダムにトークンを選択
  next_token = torch.multinomial(adjusted_probs, 1)

3. 選ばれたトークンの出力

最終的に選ばれたトークンは、現在の出力シーケンスに追加されます。次のトークンを生成するために、これまでに生成されたトークンを含む更新されたシーケンスが再びデコーダに入力されます。このプロセスは、終了トークン(例:\<eos>)が生成されるまで、または事前に設定された最大長に達するまで続けられます。

まとめ

  • 最終的なトークンの選択は、Transformerのデコーダが生成した確率分布に基づきます。
  • デコーダが生成した確率分布に対して、選択手法(グリーディーデコーディング、ビームサーチ、トップKサンプリング、トップPサンプリング、温度スケーリングなど)を適用し、次のトークンが決まります。
  • 選択のプロセスは主に、確率分布に対する数学的操作と、選択手法に応じたアルゴリズムによって行われます。
  • 選ばれたトークンはシーケンスに追加され、次のトークン生成プロセスに反映されます。

このプロセスがシーケンス全体にわたって繰り返されることで、最終的な出力テキストが生成されます。