torchtextでGloVeを使ってみる

gloveの使い方の記事のアイキャッチPython

torchtextでGloVeを使う方法を紹介していきます。

GloVeは割と有名なこともあってあらかじめtorchtextの方で用意がされているのでそれを使っていきます。

torchtextであらかじめ用意された分散表現を使うには2stepが必要です。

  1. 分散表現を読み込む(辞書作成時)
  2. nn.Embeddingにロードする

分散表現を読み込む(辞書作成時)

torchtextで辞書を作成する時は以下のようにすると思います。

TEXT.build_vocab(data)

この際に分散表現を使用して辞書を作成することができます。
これを行うとTEXT.vocab.vectorsが作成され、単語のインデックスに対応した分散表現を取得できるようになるのです。

やり方はbuild_vocab()を行う際にvectors引数に事前学習済みの分散表現を渡すだけ。
torchtextでは事前学習済みの分散表現としてGloVe, FastText, CharNGramが用意されています。

コードとしてはこのようになる。

from torchtext.vocab import GloVe

TEXT.build_vocab(data, vectors=GloVe())

ちなみにtorchtext.vocab.GloVeのソースコードを見てみるとダウンロード先がいくつかあってname引数で指定ができるみたい。

torchtext.vocab — torchtext 0.10.0 documentation

これでTEXT.vocab.vectorsが作成され、GloVeの分散表現が獲得できるようになった。
試しに取得してみる。

300次元の分散表現なので少々長いですが確かに取得することができているようです。
ちなみに単語はengineerらしい。

print(TEXT.vocab.itos[16])
print(TEXT.vocab.vectors[16])
# engineer
# tensor([ 8.6023e-03, -1.8473e-01,  1.0752e-01,  1.2761e-02,  2.3580e-02,
#          5.0082e-01, -1.7082e-01, -6.6566e-01, -5.5796e-01,  2.1522e+00,
#          2.3439e-01, -1.6664e-01, -5.5293e-01,  3.4471e-01, -8.7711e-02,
#          9.1632e-02,  3.5482e-01,  1.0917e+00,  1.3321e-02,  2.3328e-02,
#          9.3148e-01, -4.8274e-01, -5.0777e-01, -5.1008e-01, -1.1556e-01,
#         -1.0330e-02,  2.5634e-01,  6.8111e-02,  8.1666e-01, -7.5717e-02,
#         -5.1519e-04,  3.0003e-02,  3.1652e-01,  2.1723e-01,  1.4875e-01,
#          1.1622e-01,  2.3581e-01,  1.6913e-01,  1.9165e-01, -4.3172e-01,
#          5.6027e-01, -3.7024e-01,  2.6345e-01,  1.7778e-01, -3.0492e-01,
#         -2.9777e-01,  5.6449e-02, -2.0469e-01,  4.6642e-01,  3.6060e-01,
#          1.0189e+00, -3.1800e-01,  9.2015e-02,  2.2912e-01, -3.7015e-01,
#          7.3733e-02,  3.0256e-01, -1.9641e-01,  6.2234e-01, -4.2259e-01,
#         -1.4533e-01, -6.5040e-01,  4.9988e-01,  6.8856e-01,  7.8618e-01,
#         -7.4827e-02, -1.1033e-01, -3.0656e-01, -6.4927e-01,  3.0818e-01,
#         -7.1273e-01,  1.8886e-01, -5.5873e-01,  1.6748e-02, -9.9578e-03,
#         -9.2847e-02, -1.5399e-01,  2.4189e-01,  1.6602e-01,  1.4182e-01,
#          2.5637e-01, -3.0831e-01,  1.9756e-01, -3.3636e-03, -5.6348e-02,
#         -4.1302e-01, -1.9430e-01,  2.2908e-01, -3.0950e-01, -1.1492e-01,
#         -2.7326e-01, -1.5878e-01, -4.8555e-01, -2.2570e-03,  3.6427e-01,
#         -8.9268e-01,  4.0152e-01, -3.4483e-02, -2.6120e-01,  1.7357e-01,
#         -5.8714e-01,  1.2500e-01,  3.3287e-01,  5.3705e-01, -7.2950e-02,
#         -9.4287e-01, -2.4366e-02,  3.5483e-01,  3.3880e-01, -4.0253e-03,
#         -2.4033e-01,  8.6481e-02,  1.5558e-01,  6.3489e-01,  3.5502e-01,
#         -7.8141e-03,  4.7522e-01,  5.5590e-02, -1.2406e-01, -3.7251e-01,
#          5.8332e-01, -1.7430e-01,  7.6183e-02, -4.8598e-01,  5.6179e-01,
#          4.6588e-01,  1.0647e-01,  7.9589e-01, -1.3691e-01,  1.6181e-02,
#         -1.8736e-01,  4.2970e-01, -3.2155e-01, -2.5580e-01,  2.3150e-01,
#         -8.8382e-02,  3.3640e-02,  4.3296e-02,  2.5957e-01, -3.7414e-02,
#         -2.0222e-01, -1.1016e-01, -1.2248e-01, -2.3362e-01,  3.4002e-02,
#          2.0465e-01, -5.3463e-01, -1.4158e-01, -1.5363e-01, -7.0915e-01,
#          2.5605e-01, -4.6095e-01, -1.0080e+00, -4.7916e-03,  5.8724e-01,
#          2.7761e-01, -9.8296e-02,  5.0573e-02,  4.6795e-01, -2.3215e-01,
#          4.3205e-01,  8.9082e-02,  1.0690e-01, -2.1431e-01,  2.5087e-01,
#         -6.0077e-01,  5.9497e-02, -3.5759e-01,  3.9667e-01,  4.4011e-01,
#         -7.5568e-01,  1.5246e-01, -5.5170e-01, -6.0951e-01,  2.8309e-02,
#         -2.9554e-01, -2.3058e-01,  3.4464e-01,  7.2006e-02,  2.2245e-01,
#          4.6076e-01,  5.0022e-02,  1.2248e-01, -9.7854e-02, -3.9058e-01,
#          7.7901e-02,  2.4518e-01, -8.1385e-02,  9.1565e-02, -1.5852e-02,
#         -4.5259e-02,  2.2278e-02, -2.7536e-01,  4.4392e-01,  3.2147e-01,
#          2.6685e-01,  3.8927e-01, -8.9637e-02, -1.5429e-01, -5.3037e-01,
#         -3.7925e-01,  9.3671e-02,  5.5047e-02, -1.8612e-01, -5.2965e-01,
#         -5.2107e-02,  2.3737e-01, -3.2585e-01,  6.2031e-01, -6.1760e-03,
#         -1.7091e-01, -1.8027e-01, -7.7878e-03,  2.6468e-01,  2.0330e-01,
#          4.2187e-01, -4.5997e-01, -4.8386e-02, -1.2262e-01,  6.7153e-02,
#          1.3823e-01,  2.5221e-01,  8.0991e-01,  5.8619e-01,  6.7332e-01,
#         -4.1664e-01,  7.0408e-02, -5.0332e-02, -5.0495e-01,  1.0251e-01,
#         -2.8832e-01, -2.3174e-01, -3.5866e-01, -4.5292e-01, -1.7239e-01,
#         -3.5694e-01, -2.0038e-01,  3.5725e-01,  2.9845e-01,  3.7109e-01,
#          5.1490e-01, -4.9599e-01, -1.6768e-01, -2.4795e-01, -5.6059e-02,
#         -2.4221e-01,  1.4885e-02,  1.6681e-01,  3.4767e-02, -3.9645e-01,
#          1.1245e-01,  4.5854e-01,  1.7045e-01, -2.5139e-01,  4.7993e-01,
#         -1.2335e+00,  4.4126e-02, -1.3951e-01, -1.3290e-01,  2.0486e-02,
#          4.3671e-01,  9.2521e-01,  9.6463e-02, -3.1544e-01, -4.9281e-01,
#         -3.5787e-01,  1.0986e-01, -2.2779e-01,  3.5474e-01,  5.4375e-01,
#          1.3176e-01,  8.4382e-02,  3.0393e-01,  6.5258e-01, -5.4010e-01,
#          5.3910e-03,  2.6530e-01,  1.5273e-01,  1.8205e-01, -4.4993e-01,
#         -5.1126e-02, -1.4226e-01, -1.4339e-01,  4.6857e-01, -4.8341e-01,
#         -5.1560e-01,  1.0630e-01, -3.6178e-01, -2.7622e-01, -2.7273e-01,
#          7.8592e-02,  2.2751e-01,  2.7308e-01,  4.1631e-01,  1.4083e-01,
#         -3.5819e-01,  2.2986e-01,  7.9156e-01, -3.5246e-01, -8.5015e-02])

nn.Embeddingにロードして使う

実際に分散表現を使っていくには埋め込み層であるnn.Embeddingを使うので先ほど作った辞書の分散表現をnn.Embeddingが使えるようにする必要がある。

事前学習済みの分散表現を使う場合は以下のようにしてnn.Embeddingを使う。

import torch.nn as nn

embedding = nn.Embedding.from_pretrained(
    embeddings=TEXT.vocab.vectors, freeze=True
    # freeze=True であれば学習中に更新がされない(デフォルト:True)
)

試しに使ってみる。
まずはnn.Embeddingでなく直接取得する。
(上述のと同じ)

単語は同じくengineer

print(TEXT.vocab.itos[16])
print(TEXT.vocab.vectors[16])
# engineer
# tensor([ 8.6023e-03, -1.8473e-01,  1.0752e-01,  1.2761e-02,  2.3580e-02,
#          5.0082e-01, -1.7082e-01, -6.6566e-01, -5.5796e-01,  2.1522e+00,
#          2.3439e-01, -1.6664e-01, -5.5293e-01,  3.4471e-01, -8.7711e-02,
#          9.1632e-02,  3.5482e-01,  1.0917e+00,  1.3321e-02,  2.3328e-02,
#          9.3148e-01, -4.8274e-01, -5.0777e-01, -5.1008e-01, -1.1556e-01,
#         -1.0330e-02,  2.5634e-01,  6.8111e-02,  8.1666e-01, -7.5717e-02,
#         -5.1519e-04,  3.0003e-02,  3.1652e-01,  2.1723e-01,  1.4875e-01,
#          1.1622e-01,  2.3581e-01,  1.6913e-01,  1.9165e-01, -4.3172e-01,
#          5.6027e-01, -3.7024e-01,  2.6345e-01,  1.7778e-01, -3.0492e-01,
#         -2.9777e-01,  5.6449e-02, -2.0469e-01,  4.6642e-01,  3.6060e-01,
#          1.0189e+00, -3.1800e-01,  9.2015e-02,  2.2912e-01, -3.7015e-01,
#          7.3733e-02,  3.0256e-01, -1.9641e-01,  6.2234e-01, -4.2259e-01,
#         -1.4533e-01, -6.5040e-01,  4.9988e-01,  6.8856e-01,  7.8618e-01,
#         -7.4827e-02, -1.1033e-01, -3.0656e-01, -6.4927e-01,  3.0818e-01,
#         -7.1273e-01,  1.8886e-01, -5.5873e-01,  1.6748e-02, -9.9578e-03,
#         -9.2847e-02, -1.5399e-01,  2.4189e-01,  1.6602e-01,  1.4182e-01,
#          2.5637e-01, -3.0831e-01,  1.9756e-01, -3.3636e-03, -5.6348e-02,
#         -4.1302e-01, -1.9430e-01,  2.2908e-01, -3.0950e-01, -1.1492e-01,
#         -2.7326e-01, -1.5878e-01, -4.8555e-01, -2.2570e-03,  3.6427e-01,
#         -8.9268e-01,  4.0152e-01, -3.4483e-02, -2.6120e-01,  1.7357e-01,
#         -5.8714e-01,  1.2500e-01,  3.3287e-01,  5.3705e-01, -7.2950e-02,
#         -9.4287e-01, -2.4366e-02,  3.5483e-01,  3.3880e-01, -4.0253e-03,
#         -2.4033e-01,  8.6481e-02,  1.5558e-01,  6.3489e-01,  3.5502e-01,
#         -7.8141e-03,  4.7522e-01,  5.5590e-02, -1.2406e-01, -3.7251e-01,
#          5.8332e-01, -1.7430e-01,  7.6183e-02, -4.8598e-01,  5.6179e-01,
#          4.6588e-01,  1.0647e-01,  7.9589e-01, -1.3691e-01,  1.6181e-02,
#         -1.8736e-01,  4.2970e-01, -3.2155e-01, -2.5580e-01,  2.3150e-01,
#         -8.8382e-02,  3.3640e-02,  4.3296e-02,  2.5957e-01, -3.7414e-02,
#         -2.0222e-01, -1.1016e-01, -1.2248e-01, -2.3362e-01,  3.4002e-02,
#          2.0465e-01, -5.3463e-01, -1.4158e-01, -1.5363e-01, -7.0915e-01,
#          2.5605e-01, -4.6095e-01, -1.0080e+00, -4.7916e-03,  5.8724e-01,
#          2.7761e-01, -9.8296e-02,  5.0573e-02,  4.6795e-01, -2.3215e-01,
#          4.3205e-01,  8.9082e-02,  1.0690e-01, -2.1431e-01,  2.5087e-01,
#         -6.0077e-01,  5.9497e-02, -3.5759e-01,  3.9667e-01,  4.4011e-01,
#         -7.5568e-01,  1.5246e-01, -5.5170e-01, -6.0951e-01,  2.8309e-02,
#         -2.9554e-01, -2.3058e-01,  3.4464e-01,  7.2006e-02,  2.2245e-01,
#          4.6076e-01,  5.0022e-02,  1.2248e-01, -9.7854e-02, -3.9058e-01,
#          7.7901e-02,  2.4518e-01, -8.1385e-02,  9.1565e-02, -1.5852e-02,
#         -4.5259e-02,  2.2278e-02, -2.7536e-01,  4.4392e-01,  3.2147e-01,
#          2.6685e-01,  3.8927e-01, -8.9637e-02, -1.5429e-01, -5.3037e-01,
#         -3.7925e-01,  9.3671e-02,  5.5047e-02, -1.8612e-01, -5.2965e-01,
#         -5.2107e-02,  2.3737e-01, -3.2585e-01,  6.2031e-01, -6.1760e-03,
#         -1.7091e-01, -1.8027e-01, -7.7878e-03,  2.6468e-01,  2.0330e-01,
#          4.2187e-01, -4.5997e-01, -4.8386e-02, -1.2262e-01,  6.7153e-02,
#          1.3823e-01,  2.5221e-01,  8.0991e-01,  5.8619e-01,  6.7332e-01,
#         -4.1664e-01,  7.0408e-02, -5.0332e-02, -5.0495e-01,  1.0251e-01,
#         -2.8832e-01, -2.3174e-01, -3.5866e-01, -4.5292e-01, -1.7239e-01,
#         -3.5694e-01, -2.0038e-01,  3.5725e-01,  2.9845e-01,  3.7109e-01,
#          5.1490e-01, -4.9599e-01, -1.6768e-01, -2.4795e-01, -5.6059e-02,
#         -2.4221e-01,  1.4885e-02,  1.6681e-01,  3.4767e-02, -3.9645e-01,
#          1.1245e-01,  4.5854e-01,  1.7045e-01, -2.5139e-01,  4.7993e-01,
#         -1.2335e+00,  4.4126e-02, -1.3951e-01, -1.3290e-01,  2.0486e-02,
#          4.3671e-01,  9.2521e-01,  9.6463e-02, -3.1544e-01, -4.9281e-01,
#         -3.5787e-01,  1.0986e-01, -2.2779e-01,  3.5474e-01,  5.4375e-01,
#          1.3176e-01,  8.4382e-02,  3.0393e-01,  6.5258e-01, -5.4010e-01,
#          5.3910e-03,  2.6530e-01,  1.5273e-01,  1.8205e-01, -4.4993e-01,
#         -5.1126e-02, -1.4226e-01, -1.4339e-01,  4.6857e-01, -4.8341e-01,
#         -5.1560e-01,  1.0630e-01, -3.6178e-01, -2.7622e-01, -2.7273e-01,
#          7.8592e-02,  2.2751e-01,  2.7308e-01,  4.1631e-01,  1.4083e-01,
#         -3.5819e-01,  2.2986e-01,  7.9156e-01, -3.5246e-01, -8.5015e-02])

次にnn.Embeddingを使って取得してみる。

embedding = nn.Embedding.from_pretrained(
    embeddings=TEXT.vocab.vectors, freeze=True
)

embedding(torch.tensor([16]))
# tensor([[ 8.6023e-03, -1.8473e-01,  1.0752e-01,  1.2761e-02,  2.3580e-02,
#           5.0082e-01, -1.7082e-01, -6.6566e-01, -5.5796e-01,  2.1522e+00,
#           2.3439e-01, -1.6664e-01, -5.5293e-01,  3.4471e-01, -8.7711e-02,
#           9.1632e-02,  3.5482e-01,  1.0917e+00,  1.3321e-02,  2.3328e-02,
#           9.3148e-01, -4.8274e-01, -5.0777e-01, -5.1008e-01, -1.1556e-01,
#          -1.0330e-02,  2.5634e-01,  6.8111e-02,  8.1666e-01, -7.5717e-02,
#          -5.1519e-04,  3.0003e-02,  3.1652e-01,  2.1723e-01,  1.4875e-01,
#           1.1622e-01,  2.3581e-01,  1.6913e-01,  1.9165e-01, -4.3172e-01,
#           5.6027e-01, -3.7024e-01,  2.6345e-01,  1.7778e-01, -3.0492e-01,
#          -2.9777e-01,  5.6449e-02, -2.0469e-01,  4.6642e-01,  3.6060e-01,
#           1.0189e+00, -3.1800e-01,  9.2015e-02,  2.2912e-01, -3.7015e-01,
#           7.3733e-02,  3.0256e-01, -1.9641e-01,  6.2234e-01, -4.2259e-01,
#          -1.4533e-01, -6.5040e-01,  4.9988e-01,  6.8856e-01,  7.8618e-01,
#          -7.4827e-02, -1.1033e-01, -3.0656e-01, -6.4927e-01,  3.0818e-01,
#          -7.1273e-01,  1.8886e-01, -5.5873e-01,  1.6748e-02, -9.9578e-03,
#          -9.2847e-02, -1.5399e-01,  2.4189e-01,  1.6602e-01,  1.4182e-01,
#           2.5637e-01, -3.0831e-01,  1.9756e-01, -3.3636e-03, -5.6348e-02,
#          -4.1302e-01, -1.9430e-01,  2.2908e-01, -3.0950e-01, -1.1492e-01,
#          -2.7326e-01, -1.5878e-01, -4.8555e-01, -2.2570e-03,  3.6427e-01,
#          -8.9268e-01,  4.0152e-01, -3.4483e-02, -2.6120e-01,  1.7357e-01,
#          -5.8714e-01,  1.2500e-01,  3.3287e-01,  5.3705e-01, -7.2950e-02,
#          -9.4287e-01, -2.4366e-02,  3.5483e-01,  3.3880e-01, -4.0253e-03,
#          -2.4033e-01,  8.6481e-02,  1.5558e-01,  6.3489e-01,  3.5502e-01,
#          -7.8141e-03,  4.7522e-01,  5.5590e-02, -1.2406e-01, -3.7251e-01,
#           5.8332e-01, -1.7430e-01,  7.6183e-02, -4.8598e-01,  5.6179e-01,
#           4.6588e-01,  1.0647e-01,  7.9589e-01, -1.3691e-01,  1.6181e-02,
#          -1.8736e-01,  4.2970e-01, -3.2155e-01, -2.5580e-01,  2.3150e-01,
#          -8.8382e-02,  3.3640e-02,  4.3296e-02,  2.5957e-01, -3.7414e-02,
#          -2.0222e-01, -1.1016e-01, -1.2248e-01, -2.3362e-01,  3.4002e-02,
#           2.0465e-01, -5.3463e-01, -1.4158e-01, -1.5363e-01, -7.0915e-01,
#           2.5605e-01, -4.6095e-01, -1.0080e+00, -4.7916e-03,  5.8724e-01,
#           2.7761e-01, -9.8296e-02,  5.0573e-02,  4.6795e-01, -2.3215e-01,
#           4.3205e-01,  8.9082e-02,  1.0690e-01, -2.1431e-01,  2.5087e-01,
#          -6.0077e-01,  5.9497e-02, -3.5759e-01,  3.9667e-01,  4.4011e-01,
#          -7.5568e-01,  1.5246e-01, -5.5170e-01, -6.0951e-01,  2.8309e-02,
#          -2.9554e-01, -2.3058e-01,  3.4464e-01,  7.2006e-02,  2.2245e-01,
#           4.6076e-01,  5.0022e-02,  1.2248e-01, -9.7854e-02, -3.9058e-01,
#           7.7901e-02,  2.4518e-01, -8.1385e-02,  9.1565e-02, -1.5852e-02,
#          -4.5259e-02,  2.2278e-02, -2.7536e-01,  4.4392e-01,  3.2147e-01,
#           2.6685e-01,  3.8927e-01, -8.9637e-02, -1.5429e-01, -5.3037e-01,
#          -3.7925e-01,  9.3671e-02,  5.5047e-02, -1.8612e-01, -5.2965e-01,
#          -5.2107e-02,  2.3737e-01, -3.2585e-01,  6.2031e-01, -6.1760e-03,
#          -1.7091e-01, -1.8027e-01, -7.7878e-03,  2.6468e-01,  2.0330e-01,
#           4.2187e-01, -4.5997e-01, -4.8386e-02, -1.2262e-01,  6.7153e-02,
#           1.3823e-01,  2.5221e-01,  8.0991e-01,  5.8619e-01,  6.7332e-01,
#          -4.1664e-01,  7.0408e-02, -5.0332e-02, -5.0495e-01,  1.0251e-01,
#          -2.8832e-01, -2.3174e-01, -3.5866e-01, -4.5292e-01, -1.7239e-01,
#          -3.5694e-01, -2.0038e-01,  3.5725e-01,  2.9845e-01,  3.7109e-01,
#           5.1490e-01, -4.9599e-01, -1.6768e-01, -2.4795e-01, -5.6059e-02,
#          -2.4221e-01,  1.4885e-02,  1.6681e-01,  3.4767e-02, -3.9645e-01,
#           1.1245e-01,  4.5854e-01,  1.7045e-01, -2.5139e-01,  4.7993e-01,
#          -1.2335e+00,  4.4126e-02, -1.3951e-01, -1.3290e-01,  2.0486e-02,
#           4.3671e-01,  9.2521e-01,  9.6463e-02, -3.1544e-01, -4.9281e-01,
#          -3.5787e-01,  1.0986e-01, -2.2779e-01,  3.5474e-01,  5.4375e-01,
#           1.3176e-01,  8.4382e-02,  3.0393e-01,  6.5258e-01, -5.4010e-01,
#           5.3910e-03,  2.6530e-01,  1.5273e-01,  1.8205e-01, -4.4993e-01,
#          -5.1126e-02, -1.4226e-01, -1.4339e-01,  4.6857e-01, -4.8341e-01,
#          -5.1560e-01,  1.0630e-01, -3.6178e-01, -2.7622e-01, -2.7273e-01,
#           7.8592e-02,  2.2751e-01,  2.7308e-01,  4.1631e-01,  1.4083e-01,
#          -3.5819e-01,  2.2986e-01,  7.9156e-01, -3.5246e-01, -8.5015e-02]])

nn.Embeddingを使っても正しく分散表現を取得できてる!

あとはこんな感じにモデルに組み込んでいくだけ。

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(
            embeddings=TEXT.vocab.vectors,
            freeze=True
        )
        self.linear = nn.Linear(300, 10)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear(x)
        return x
    
model = Model()

参考

torchtext.vocab — torchtext 0.10.0 documentation
Embedding — PyTorch 1.9.1 documentation

コメント

タイトルとURLをコピーしました