LLMにおけるInput Embedding(埋め込み)の具体的な計算例

Input Embedding = Token Embedding + Segment Embedding + Position Embedding という関係を具体的な数値例を用いて詳しく解説します。このセクションでは、以下のステップに沿って具体例を示します。

  1. 入力テキストの設定
  2. トークン化とIDの割り当て
  3. セグメントIDの割り当て
  4. 位置IDの生成
  5. 各埋め込みの取得
  6. Input Embeddingの計算
  7. 最終的な結果の解釈

1. 入力テキストの設定

例として、以下の2つの文をBERTモデルに入力すると仮定します。

  • 文A: “私は昨日、図書館で本を読みました。”
  • 文B: “その本はとても面白かったです。”

2. トークン化とIDの割り当て

まず、入力テキストをトークン化し、各トークンに一意のIDを割り当てます。ここでは、簡略化のために小さな語彙と短いベクトルを使用します。

語彙とトークンIDの例:

トークントークンID
[CLS]101
2001
2002
昨日2003
2004
図書館2005
2006
2007
2008
読みました2009
2010
[SEP]102
その2011
とても2012
面白かった2013
です2014

トークン化された入力:

[CLS] 私 は 昨日 、 図書館 で 本 を 読みました 。 [SEP] その 本 は とても 面白かった です 。 [SEP]

トークンID:

[101, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 102, 2011, 2007, 2002, 2012, 2013, 2014, 2010, 102]

3. セグメントIDの割り当て

BERTでは、2つのセグメント(文Aと文B)を区別するためにセグメントIDを使用します。通常、セグメントAにはID 0、セグメントBにはID 1 を割り当てます。

セグメントID:

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]

4. 位置IDの生成

各トークンの位置を示すために、位置IDを生成します。位置IDは通常、トークンの出現順に 0 から始まります。

位置ID:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

5. 各埋め込みの取得

次に、トークン埋め込み、セグメント埋め込み、位置埋め込みを取得し、それぞれのベクトルを合算します。ここでは、埋め込み次元を 3次元 に設定し、具体的な数値を示します。

埋め込みベクトルの設定(仮定):

  • Token Embedding: 各トークンIDに対応する3次元のベクトル。
  • Segment Embedding: セグメントID 01 に対応する3次元のベクトル。
  • Position Embedding: 各位置IDに対応する3次元のベクトル。

具体的なベクトル例:

埋め込みタイプIDベクトル
Token101[0.1, 0.2, 0.3]
Token2001[0.4, 0.5, 0.6]
Token2002[0.7, 0.8, 0.9]
Token2003[1.0, 1.1, 1.2]
Token2004[1.3, 1.4, 1.5]
Token2005[1.6, 1.7, 1.8]
Token2006[1.9, 2.0, 2.1]
Token2007[2.2, 2.3, 2.4]
Token2008[2.5, 2.6, 2.7]
Token2009[2.8, 2.9, 3.0]
Token2010[3.1, 3.2, 3.3]
Token102[3.4, 3.5, 3.6]
Token2011[3.7, 3.8, 3.9]
Token2012[4.0, 4.1, 4.2]
Token2013[4.3, 4.4, 4.5]
Token2014[4.6, 4.7, 4.8]
Segment0[0.01, 0.02, 0.03]
Segment1[0.04, 0.05, 0.06]
Position0[0.001, 0.002, 0.003]
Position1[0.004, 0.005, 0.006]
Position19[0.058, 0.059, 0.060]

: 上記の数値は仮定のものであり、実際のモデルでは埋め込みベクトルはランダムに初期化され、学習を通じて最適化されます。

具体例として、最初の数トークンについて埋め込みベクトルを示します。

トークントークンIDセグメントID位置IDToken EmbeddingSegment EmbeddingPosition EmbeddingInput Embedding (合計)
[CLS]10100[0.1, 0.2, 0.3][0.01, 0.02, 0.03][0.001, 0.002, 0.003][0.111, 0.222, 0.333]
200101[0.4, 0.5, 0.6][0.01, 0.02, 0.03][0.004, 0.005, 0.006][0.414, 0.525, 0.636]
200202[0.7, 0.8, 0.9][0.01, 0.02, 0.03][0.007, 0.008, 0.009][0.717, 0.828, 0.939]
昨日200303[1.0, 1.1, 1.2][0.01, 0.02, 0.03][0.010, 0.011, 0.012][1.020, 1.131, 1.242]
200404[1.3, 1.4, 1.5][0.01, 0.02, 0.03][0.013, 0.014, 0.015][1.323, 1.434, 1.545]

計算方法:

Input Embedding = Token Embedding + Segment Embedding + Position Embedding
例:
[CLS]:
[0.1, 0.2, 0.3] + [0.01, 0.02, 0.03] + [0.001, 0.002, 0.003] = [0.111, 0.222, 0.333]

6. Input Embeddingの計算

全てのトークンについて同様に埋め込みを合算します。以下に、いくつかのトークンについて計算を示します。

トークントークンIDセグメントID位置IDToken EmbeddingSegment EmbeddingPosition EmbeddingInput Embedding
[SEP]102011[3.4, 3.5, 3.6][0.01, 0.02, 0.03][0.037, 0.038, 0.039][3.447, 3.558, 3.669]
その2011112[3.7, 3.8, 3.9][0.04, 0.05, 0.06][0.040, 0.041, 0.042][3.780, 3.891, 4.002]
2007113[2.2, 2.3, 2.4][0.04, 0.05, 0.06][0.043, 0.044, 0.045][2.283, 2.394, 2.505]
2002114[0.7, 0.8, 0.9][0.04, 0.05, 0.06][0.046, 0.047, 0.048][0.786, 0.897, 1.008]
とても2012115[4.0, 4.1, 4.2][0.04, 0.05, 0.06][0.049, 0.050, 0.051][4.089, 4.200, 4.311]
面白かった2013116[4.3, 4.4, 4.5][0.04, 0.05, 0.06][0.052, 0.053, 0.054][4.392, 4.503, 4.614]
です2014117[4.6, 4.7, 4.8][0.04, 0.05, 0.06][0.055, 0.056, 0.057][4.695, 4.806, 4.917]
2010118[3.1, 3.2, 3.3][0.04, 0.05, 0.06][0.058, 0.059, 0.060][3.198, 3.309, 3.420]
[SEP]102119[3.4, 3.5, 3.6][0.04, 0.05, 0.06][0.061, 0.062, 0.063][3.501, 3.612, 3.723]

7. 最終的な結果の解釈

最終的なInput Embeddingは、各トークンについて Token EmbeddingSegment EmbeddingPosition Embedding を合算したベクトルです。これにより、モデルは以下の情報を一つのベクトルに統合して利用できます。

  • Token Embedding: 各トークン自体の意味的な情報。
  • Segment Embedding: トークンがどのセグメント(文Aまたは文B)に属するかの情報。
  • Position Embedding: トークンの位置情報。

これにより、モデルは入力テキストの構造や文脈を効果的に理解できます。

例: トークン [CLS]

埋め込みタイプベクトル
Token Embedding[0.1, 0.2, 0.3]
Segment Embedding[0.01, 0.02, 0.03]
Position Embedding[0.001, 0.002, 0.003]
合計[0.111, 0.222, 0.333]

このベクトル [0.111, 0.222, 0.333] が、トランスフォーマーモデル(例: BERT)の入力として使用され、モデルはこれを基に文脈理解やタスク遂行を行います。

完全な入力埋め込みの表

以下に、全トークンのInput Embeddingをまとめた表を示します。

トークンInput Embedding
[CLS][0.111, 0.222, 0.333]
[0.414, 0.525, 0.636]
[0.717, 0.828, 0.939]
昨日[1.020, 1.131, 1.242]
[1.323, 1.434, 1.545]
図書館[1.626, 1.737, 1.848]
[1.929, 2.040, 2.151]
[2.232, 2.343, 2.454]
[2.535, 2.646, 2.757]
読みました[2.838, 2.949, 3.060]
[3.141, 3.252, 3.363]
[SEP][3.447, 3.558, 3.669]
その[3.780, 3.891, 4.002]
[2.283, 2.394, 2.505]
[0.786, 0.897, 1.008]
とても[4.089, 4.200, 4.311]
面白かった[4.392, 4.503, 4.614]
です[4.695, 4.806, 4.917]
[3.198, 3.309, 3.420]
[SEP][3.501, 3.612, 3.723]

8. PyTorchによる具体例の実装

以下に、PyTorchを用いて上述の具体例を実装し、計算を確認する方法を示します。

import torch
import torch.nn as nn

# 埋め込み次元を3に設定
embedding_dim = 3

# 語彙サイズ
vocab_size = 3000  # 実際の語彙より大きく設定
segment_size = 2
max_position_embeddings = 20

# 埋め込み層の定義
token_embedding = nn.Embedding(vocab_size, embedding_dim)
segment_embedding = nn.Embedding(segment_size, embedding_dim)
position_embedding = nn.Embedding(max_position_embeddings, embedding_dim)

# 仮のトークンID(先ほどの例に基づく)
token_ids = torch.tensor([
    101, 2001, 2002, 2003, 2004, 2005, 2006, 2007,
    2008, 2009, 2010, 102, 2011, 2007, 2002, 2012,
    2013, 2014, 2010, 102
])

# セグメントIDの割り当て
segment_ids = torch.tensor([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1
])

# 位置IDの生成
position_ids = torch.arange(token_ids.size(0)).unsqueeze(0)

# 仮の埋め込みベクトルを設定(事前に値を設定)
# ここでは埋め込み層の重みを手動で設定します
with torch.no_grad():
    # Token Embeddingの設定
    token_weights = torch.zeros(vocab_size, embedding_dim)
    token_weights[101] = torch.tensor([0.1, 0.2, 0.3])
    token_weights[2001] = torch.tensor([0.4, 0.5, 0.6])
    token_weights[2002] = torch.tensor([0.7, 0.8, 0.9])
    token_weights[2003] = torch.tensor([1.0, 1.1, 1.2])
    token_weights[2004] = torch.tensor([1.3, 1.4, 1.5])
    token_weights[2005] = torch.tensor([1.6, 1.7, 1.8])
    token_weights[2006] = torch.tensor([1.9, 2.0, 2.1])
    token_weights[2007] = torch.tensor([2.2, 2.3, 2.4])
    token_weights[2008] = torch.tensor([2.5, 2.6, 2.7])
    token_weights[2009] = torch.tensor([2.8, 2.9, 3.0])
    token_weights[2010] = torch.tensor([3.1, 3.2, 3.3])
    token_weights[102] = torch.tensor([3.4, 3.5, 3.6])
    token_weights[2011] = torch.tensor([3.7, 3.8, 3.9])
    token_weights[2012] = torch.tensor([4.0, 4.1, 4.2])
    token_weights[2013] = torch.tensor([4.3, 4.4, 4.5])
    token_weights[2014] = torch.tensor([4.6, 4.7, 4.8])
    token_embedding.weight.copy_(token_weights)

    # Segment Embeddingの設定
    segment_weights = torch.tensor([
        [0.01, 0.02, 0.03],  # セグメントID 0
        [0.04, 0.05, 0.06]   # セグメントID 1
    ])
    segment_embedding.weight.copy_(segment_weights)

    # Position Embeddingの設定
    position_weights = torch.zeros(max_position_embeddings, embedding_dim)
    for pos in range(max_position_embeddings):
        position_weights[pos] = torch.tensor([0.001 * (pos + 1), 0.002 * (pos + 1), 0.003 * (pos + 1)])
    position_embedding.weight.copy_(position_weights)

# 埋め込みの取得
token_embeds = token_embedding(token_ids)        # (シーケンス長, 埋め込み次元)
segment_embeds = segment_embedding(segment_ids) # (シーケンス長, 埋め込み次元)
position_embeds = position_embedding(position_ids) # (1, シーケンス長, 埋め込み次元)

# 埋め込みの合算
input_embeddings = token_embeds + segment_embeds + position_embeds.squeeze(0)

# 結果の表示
for idx, token_id in enumerate(token_ids):
    token = f"TokenID {token_id}"
    embedding = input_embeddings[idx].tolist()
    print(f"{token}: {embedding}")

出力例:

TokenID 101: [0.111, 0.222, 0.333]
TokenID 2001: [0.414, 0.525, 0.636]
TokenID 2002: [0.717, 0.828, 0.939]
TokenID 2003: [1.02, 1.131, 1.242]
TokenID 2004: [1.323, 1.434, 1.545]
TokenID 2005: [1.626, 1.737, 1.848]
TokenID 2006: [1.929, 2.04, 2.151]
TokenID 2007: [2.232, 2.343, 2.454]
TokenID 2008: [2.535, 2.646, 2.757]
TokenID 2009: [2.838, 2.949, 3.06]
TokenID 2010: [3.141, 3.252, 3.363]
TokenID 102: [3.447, 3.558, 3.669]
TokenID 2011: [3.78, 3.891, 4.002]
TokenID 2007: [2.283, 2.394, 2.505]
TokenID 2002: [0.786, 0.897, 1.008]
TokenID 2012: [4.089, 4.2, 4.311]
TokenID 2013: [4.392, 4.503, 4.614]
TokenID 2014: [4.695, 4.806, 4.917]
TokenID 2010: [3.198, 3.309, 3.42]
TokenID 102: [3.501, 3.612, 3.723]

解釈:

  • 各トークンについて、Token EmbeddingSegment EmbeddingPosition Embedding の3つのベクトルが合算されています。
  • 例えば、トークンID 101([CLS])のInput Embeddingは [0.111, 0.222, 0.333] です。これは [0.1, 0.2, 0.3](Token) + [0.01, 0.02, 0.03](Segment) + [0.001, 0.002, 0.003](Position)の合計です。
  • このようにして、各トークンの最終的な埋め込みベクトルが生成されます。

9. 全体の流れをまとめる

  1. トークン化とID割り当て:
  • 入力テキストをトークン化し、各トークンに一意のIDを割り当てる。
  1. セグメントIDの割り当て:
  • 文Aと文Bを区別するためにセグメントIDを割り当てる。
  1. 位置IDの生成:
  • 各トークンの位置を示す位置IDを生成する。
  1. 埋め込みの取得:
  • トークンID、セグメントID、位置IDに基づいてそれぞれの埋め込みベクトルを取得する。
  1. Input Embeddingの計算:
  • 各トークンについて、Token Embedding、Segment Embedding、Position Embeddingを合算して最終的なInput Embeddingを生成する。
  1. モデルへの入力:
  • 生成されたInput Embeddingをトランスフォーマーモデルに入力し、文脈理解やタスク遂行を行う。

10. まとめ

この具体例を通じて、Input Embedding = Token Embedding + Segment Embedding + Position Embedding がどのように計算されるかを詳細に理解できたと思います。以下にポイントをまとめます。

  • Token Embedding: 各トークンの意味的な情報を保持。
  • Segment Embedding: トークンがどのセグメント(文Aまたは文B)に属するかを示す情報。
  • Position Embedding: トークンの位置情報を提供し、文脈内での順序を保持。

これら3つの埋め込みを合算することで、モデルは各トークンの意味、セグメント、位置という多面的な情報を統合して理解することができます。この統合された情報が、トランスフォーマーモデルの層を通じて高度な文脈理解やタスク遂行を可能にします。

さらに、実際の大規模モデルでは、これらの埋め込みベクトルは数百次元に設定され、数百万以上のパラメータが学習されます。上記の例は理解を深めるために簡略化されていますが、基本的な概念は同じです。