Here is how to use GloVe with torchtext. GloVe is well known enough that torchtext ships with it out of the box, so that is what we will use. Using one of torchtext's bundled pre-trained embeddings takes two steps:
- Load the embeddings (when building the vocabulary)
- Load them into nn.Embedding
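The two steps can be sketched end to end with a tiny toy weight matrix standing in for the GloVe vectors (the matrix below is made up purely for illustration; the real vectors come from the vocabulary built in the next section):

```python
import torch
import torch.nn as nn

# Step 1 (stand-in): a "pre-trained" matrix of 4 words x 3 dims
vectors = torch.tensor([[0.0, 0.1, 0.2],
                        [1.0, 1.1, 1.2],
                        [2.0, 2.1, 2.2],
                        [3.0, 3.1, 3.2]])

# Step 2: load the matrix into an embedding layer
embedding = nn.Embedding.from_pretrained(vectors, freeze=True)

# Looking up index 2 returns the third row of the matrix
print(embedding(torch.tensor([2])))
# tensor([[2.0000, 2.1000, 2.2000]])
```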
Loading the embeddings (when building the vocabulary)
When building a vocabulary with torchtext, you typically write something like this:
TEXT.build_vocab(data)
At this point, you can build the vocabulary using pre-trained embeddings. Doing so creates TEXT.vocab.vectors, which maps each word's index to its embedding vector. All you have to do is pass the pre-trained embeddings to the vectors argument when calling build_vocab(). torchtext bundles GloVe, FastText, and CharNGram as pre-trained embeddings. The code looks like this:
from torchtext.vocab import GloVe
TEXT.build_vocab(data, vectors=GloVe())
Incidentally, looking at the source code of torchtext.vocab.GloVe, there appear to be several download targets (variants such as 6B, 42B, 840B, and twitter.27B), selectable via the name argument.
With this, TEXT.vocab.vectors is created and the GloVe embeddings are available. Let's try retrieving one. The vectors are 300-dimensional, so the output is a bit long, but the lookup clearly works. Incidentally, the word at index 16 happens to be engineer.
print(TEXT.vocab.itos[16])
print(TEXT.vocab.vectors[16])
# engineer
# tensor([ 8.6023e-03, -1.8473e-01, 1.0752e-01, 1.2761e-02, 2.3580e-02,
# 5.0082e-01, -1.7082e-01, -6.6566e-01, -5.5796e-01, 2.1522e+00,
# 2.3439e-01, -1.6664e-01, -5.5293e-01, 3.4471e-01, -8.7711e-02,
# 9.1632e-02, 3.5482e-01, 1.0917e+00, 1.3321e-02, 2.3328e-02,
# 9.3148e-01, -4.8274e-01, -5.0777e-01, -5.1008e-01, -1.1556e-01,
# -1.0330e-02, 2.5634e-01, 6.8111e-02, 8.1666e-01, -7.5717e-02,
# -5.1519e-04, 3.0003e-02, 3.1652e-01, 2.1723e-01, 1.4875e-01,
# 1.1622e-01, 2.3581e-01, 1.6913e-01, 1.9165e-01, -4.3172e-01,
# 5.6027e-01, -3.7024e-01, 2.6345e-01, 1.7778e-01, -3.0492e-01,
# -2.9777e-01, 5.6449e-02, -2.0469e-01, 4.6642e-01, 3.6060e-01,
# 1.0189e+00, -3.1800e-01, 9.2015e-02, 2.2912e-01, -3.7015e-01,
# 7.3733e-02, 3.0256e-01, -1.9641e-01, 6.2234e-01, -4.2259e-01,
# -1.4533e-01, -6.5040e-01, 4.9988e-01, 6.8856e-01, 7.8618e-01,
# -7.4827e-02, -1.1033e-01, -3.0656e-01, -6.4927e-01, 3.0818e-01,
# -7.1273e-01, 1.8886e-01, -5.5873e-01, 1.6748e-02, -9.9578e-03,
# -9.2847e-02, -1.5399e-01, 2.4189e-01, 1.6602e-01, 1.4182e-01,
# 2.5637e-01, -3.0831e-01, 1.9756e-01, -3.3636e-03, -5.6348e-02,
# -4.1302e-01, -1.9430e-01, 2.2908e-01, -3.0950e-01, -1.1492e-01,
# -2.7326e-01, -1.5878e-01, -4.8555e-01, -2.2570e-03, 3.6427e-01,
# -8.9268e-01, 4.0152e-01, -3.4483e-02, -2.6120e-01, 1.7357e-01,
# -5.8714e-01, 1.2500e-01, 3.3287e-01, 5.3705e-01, -7.2950e-02,
# -9.4287e-01, -2.4366e-02, 3.5483e-01, 3.3880e-01, -4.0253e-03,
# -2.4033e-01, 8.6481e-02, 1.5558e-01, 6.3489e-01, 3.5502e-01,
# -7.8141e-03, 4.7522e-01, 5.5590e-02, -1.2406e-01, -3.7251e-01,
# 5.8332e-01, -1.7430e-01, 7.6183e-02, -4.8598e-01, 5.6179e-01,
# 4.6588e-01, 1.0647e-01, 7.9589e-01, -1.3691e-01, 1.6181e-02,
# -1.8736e-01, 4.2970e-01, -3.2155e-01, -2.5580e-01, 2.3150e-01,
# -8.8382e-02, 3.3640e-02, 4.3296e-02, 2.5957e-01, -3.7414e-02,
# -2.0222e-01, -1.1016e-01, -1.2248e-01, -2.3362e-01, 3.4002e-02,
# 2.0465e-01, -5.3463e-01, -1.4158e-01, -1.5363e-01, -7.0915e-01,
# 2.5605e-01, -4.6095e-01, -1.0080e+00, -4.7916e-03, 5.8724e-01,
# 2.7761e-01, -9.8296e-02, 5.0573e-02, 4.6795e-01, -2.3215e-01,
# 4.3205e-01, 8.9082e-02, 1.0690e-01, -2.1431e-01, 2.5087e-01,
# -6.0077e-01, 5.9497e-02, -3.5759e-01, 3.9667e-01, 4.4011e-01,
# -7.5568e-01, 1.5246e-01, -5.5170e-01, -6.0951e-01, 2.8309e-02,
# -2.9554e-01, -2.3058e-01, 3.4464e-01, 7.2006e-02, 2.2245e-01,
# 4.6076e-01, 5.0022e-02, 1.2248e-01, -9.7854e-02, -3.9058e-01,
# 7.7901e-02, 2.4518e-01, -8.1385e-02, 9.1565e-02, -1.5852e-02,
# -4.5259e-02, 2.2278e-02, -2.7536e-01, 4.4392e-01, 3.2147e-01,
# 2.6685e-01, 3.8927e-01, -8.9637e-02, -1.5429e-01, -5.3037e-01,
# -3.7925e-01, 9.3671e-02, 5.5047e-02, -1.8612e-01, -5.2965e-01,
# -5.2107e-02, 2.3737e-01, -3.2585e-01, 6.2031e-01, -6.1760e-03,
# -1.7091e-01, -1.8027e-01, -7.7878e-03, 2.6468e-01, 2.0330e-01,
# 4.2187e-01, -4.5997e-01, -4.8386e-02, -1.2262e-01, 6.7153e-02,
# 1.3823e-01, 2.5221e-01, 8.0991e-01, 5.8619e-01, 6.7332e-01,
# -4.1664e-01, 7.0408e-02, -5.0332e-02, -5.0495e-01, 1.0251e-01,
# -2.8832e-01, -2.3174e-01, -3.5866e-01, -4.5292e-01, -1.7239e-01,
# -3.5694e-01, -2.0038e-01, 3.5725e-01, 2.9845e-01, 3.7109e-01,
# 5.1490e-01, -4.9599e-01, -1.6768e-01, -2.4795e-01, -5.6059e-02,
# -2.4221e-01, 1.4885e-02, 1.6681e-01, 3.4767e-02, -3.9645e-01,
# 1.1245e-01, 4.5854e-01, 1.7045e-01, -2.5139e-01, 4.7993e-01,
# -1.2335e+00, 4.4126e-02, -1.3951e-01, -1.3290e-01, 2.0486e-02,
# 4.3671e-01, 9.2521e-01, 9.6463e-02, -3.1544e-01, -4.9281e-01,
# -3.5787e-01, 1.0986e-01, -2.2779e-01, 3.5474e-01, 5.4375e-01,
# 1.3176e-01, 8.4382e-02, 3.0393e-01, 6.5258e-01, -5.4010e-01,
# 5.3910e-03, 2.6530e-01, 1.5273e-01, 1.8205e-01, -4.4993e-01,
# -5.1126e-02, -1.4226e-01, -1.4339e-01, 4.6857e-01, -4.8341e-01,
# -5.1560e-01, 1.0630e-01, -3.6178e-01, -2.7622e-01, -2.7273e-01,
# 7.8592e-02, 2.2751e-01, 2.7308e-01, 4.1631e-01, 1.4083e-01,
# -3.5819e-01, 2.2986e-01, 7.9156e-01, -3.5246e-01, -8.5015e-02])
Loading into nn.Embedding and using it
To actually use the embeddings, we go through the embedding layer nn.Embedding, so the vectors in the vocabulary we just built need to be made available to it. With pre-trained embeddings, nn.Embedding is created as follows:
import torch.nn as nn

embedding = nn.Embedding.from_pretrained(
    embeddings=TEXT.vocab.vectors,
    freeze=True,  # if True, the weights are not updated during training (default: True)
)
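To see what freeze actually does, here is a quick check with random vectors standing in for TEXT.vocab.vectors: freeze=True turns off requires_grad on the layer's weight, so the embeddings stay fixed during training, while freeze=False leaves them trainable.

```python
import torch
import torch.nn as nn

vecs = torch.randn(10, 300)  # stand-in for TEXT.vocab.vectors

frozen = nn.Embedding.from_pretrained(vecs, freeze=True)
trainable = nn.Embedding.from_pretrained(vecs, freeze=False)

print(frozen.weight.requires_grad)     # False
print(trainable.weight.requires_grad)  # True
```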
Let's try it out. First, retrieve the vector directly, without nn.Embedding (the same lookup as above). The word is again engineer.
print(TEXT.vocab.itos[16])
print(TEXT.vocab.vectors[16])
# engineer
# (same 300-dimensional tensor as shown above)
Next, retrieve it through nn.Embedding.
embedding = nn.Embedding.from_pretrained(
    embeddings=TEXT.vocab.vectors, freeze=True
)
embedding(torch.tensor([16]))
# tensor([[ 8.6023e-03, -1.8473e-01, 1.0752e-01, 1.2761e-02, 2.3580e-02,
# 5.0082e-01, -1.7082e-01, -6.6566e-01, -5.5796e-01, 2.1522e+00,
# 2.3439e-01, -1.6664e-01, -5.5293e-01, 3.4471e-01, -8.7711e-02,
# 9.1632e-02, 3.5482e-01, 1.0917e+00, 1.3321e-02, 2.3328e-02,
# 9.3148e-01, -4.8274e-01, -5.0777e-01, -5.1008e-01, -1.1556e-01,
# -1.0330e-02, 2.5634e-01, 6.8111e-02, 8.1666e-01, -7.5717e-02,
# -5.1519e-04, 3.0003e-02, 3.1652e-01, 2.1723e-01, 1.4875e-01,
# 1.1622e-01, 2.3581e-01, 1.6913e-01, 1.9165e-01, -4.3172e-01,
# 5.6027e-01, -3.7024e-01, 2.6345e-01, 1.7778e-01, -3.0492e-01,
# -2.9777e-01, 5.6449e-02, -2.0469e-01, 4.6642e-01, 3.6060e-01,
# 1.0189e+00, -3.1800e-01, 9.2015e-02, 2.2912e-01, -3.7015e-01,
# 7.3733e-02, 3.0256e-01, -1.9641e-01, 6.2234e-01, -4.2259e-01,
# -1.4533e-01, -6.5040e-01, 4.9988e-01, 6.8856e-01, 7.8618e-01,
# -7.4827e-02, -1.1033e-01, -3.0656e-01, -6.4927e-01, 3.0818e-01,
# -7.1273e-01, 1.8886e-01, -5.5873e-01, 1.6748e-02, -9.9578e-03,
# -9.2847e-02, -1.5399e-01, 2.4189e-01, 1.6602e-01, 1.4182e-01,
# 2.5637e-01, -3.0831e-01, 1.9756e-01, -3.3636e-03, -5.6348e-02,
# -4.1302e-01, -1.9430e-01, 2.2908e-01, -3.0950e-01, -1.1492e-01,
# -2.7326e-01, -1.5878e-01, -4.8555e-01, -2.2570e-03, 3.6427e-01,
# -8.9268e-01, 4.0152e-01, -3.4483e-02, -2.6120e-01, 1.7357e-01,
# -5.8714e-01, 1.2500e-01, 3.3287e-01, 5.3705e-01, -7.2950e-02,
# -9.4287e-01, -2.4366e-02, 3.5483e-01, 3.3880e-01, -4.0253e-03,
# -2.4033e-01, 8.6481e-02, 1.5558e-01, 6.3489e-01, 3.5502e-01,
# -7.8141e-03, 4.7522e-01, 5.5590e-02, -1.2406e-01, -3.7251e-01,
# 5.8332e-01, -1.7430e-01, 7.6183e-02, -4.8598e-01, 5.6179e-01,
# 4.6588e-01, 1.0647e-01, 7.9589e-01, -1.3691e-01, 1.6181e-02,
# -1.8736e-01, 4.2970e-01, -3.2155e-01, -2.5580e-01, 2.3150e-01,
# -8.8382e-02, 3.3640e-02, 4.3296e-02, 2.5957e-01, -3.7414e-02,
# -2.0222e-01, -1.1016e-01, -1.2248e-01, -2.3362e-01, 3.4002e-02,
# 2.0465e-01, -5.3463e-01, -1.4158e-01, -1.5363e-01, -7.0915e-01,
# 2.5605e-01, -4.6095e-01, -1.0080e+00, -4.7916e-03, 5.8724e-01,
# 2.7761e-01, -9.8296e-02, 5.0573e-02, 4.6795e-01, -2.3215e-01,
# 4.3205e-01, 8.9082e-02, 1.0690e-01, -2.1431e-01, 2.5087e-01,
# -6.0077e-01, 5.9497e-02, -3.5759e-01, 3.9667e-01, 4.4011e-01,
# -7.5568e-01, 1.5246e-01, -5.5170e-01, -6.0951e-01, 2.8309e-02,
# -2.9554e-01, -2.3058e-01, 3.4464e-01, 7.2006e-02, 2.2245e-01,
# 4.6076e-01, 5.0022e-02, 1.2248e-01, -9.7854e-02, -3.9058e-01,
# 7.7901e-02, 2.4518e-01, -8.1385e-02, 9.1565e-02, -1.5852e-02,
# -4.5259e-02, 2.2278e-02, -2.7536e-01, 4.4392e-01, 3.2147e-01,
# 2.6685e-01, 3.8927e-01, -8.9637e-02, -1.5429e-01, -5.3037e-01,
# -3.7925e-01, 9.3671e-02, 5.5047e-02, -1.8612e-01, -5.2965e-01,
# -5.2107e-02, 2.3737e-01, -3.2585e-01, 6.2031e-01, -6.1760e-03,
# -1.7091e-01, -1.8027e-01, -7.7878e-03, 2.6468e-01, 2.0330e-01,
# 4.2187e-01, -4.5997e-01, -4.8386e-02, -1.2262e-01, 6.7153e-02,
# 1.3823e-01, 2.5221e-01, 8.0991e-01, 5.8619e-01, 6.7332e-01,
# -4.1664e-01, 7.0408e-02, -5.0332e-02, -5.0495e-01, 1.0251e-01,
# -2.8832e-01, -2.3174e-01, -3.5866e-01, -4.5292e-01, -1.7239e-01,
# -3.5694e-01, -2.0038e-01, 3.5725e-01, 2.9845e-01, 3.7109e-01,
# 5.1490e-01, -4.9599e-01, -1.6768e-01, -2.4795e-01, -5.6059e-02,
# -2.4221e-01, 1.4885e-02, 1.6681e-01, 3.4767e-02, -3.9645e-01,
# 1.1245e-01, 4.5854e-01, 1.7045e-01, -2.5139e-01, 4.7993e-01,
# -1.2335e+00, 4.4126e-02, -1.3951e-01, -1.3290e-01, 2.0486e-02,
# 4.3671e-01, 9.2521e-01, 9.6463e-02, -3.1544e-01, -4.9281e-01,
# -3.5787e-01, 1.0986e-01, -2.2779e-01, 3.5474e-01, 5.4375e-01,
# 1.3176e-01, 8.4382e-02, 3.0393e-01, 6.5258e-01, -5.4010e-01,
# 5.3910e-03, 2.6530e-01, 1.5273e-01, 1.8205e-01, -4.4993e-01,
# -5.1126e-02, -1.4226e-01, -1.4339e-01, 4.6857e-01, -4.8341e-01,
# -5.1560e-01, 1.0630e-01, -3.6178e-01, -2.7622e-01, -2.7273e-01,
# 7.8592e-02, 2.2751e-01, 2.7308e-01, 4.1631e-01, 1.4083e-01,
# -3.5819e-01, 2.2986e-01, 7.9156e-01, -3.5246e-01, -8.5015e-02]])
nn.Embedding returns the correct embedding as well! All that's left is to plug it into a model, like this:
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=TEXT.vocab.vectors,
            freeze=True
        )
        self.linear = nn.Linear(300, 10)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear(x)
        return x

model = Model()
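As a quick sanity check of the shapes, the same model can be run with a random matrix standing in for TEXT.vocab.vectors. nn.Linear is applied to the last dimension, so a batch of token indices of shape (batch, seq_len) comes out as (batch, seq_len, 10):

```python
import torch
import torch.nn as nn

vectors = torch.randn(100, 300)  # stand-in for TEXT.vocab.vectors

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=vectors, freeze=True
        )
        self.linear = nn.Linear(300, 10)

    def forward(self, x):
        x = self.embedding(x)  # (batch, seq_len) -> (batch, seq_len, 300)
        x = self.linear(x)     # (batch, seq_len, 300) -> (batch, seq_len, 10)
        return x

model = Model()
out = model(torch.tensor([[16, 2, 5]]))  # one sequence of three token ids
print(out.shape)  # torch.Size([1, 3, 10])
```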