これはPytorchで日本語の学習済みBERTモデルを読み込み、文章ベクトル(Sentence Embedding)を計算するためのコードです。
詳細は下記ブログを参考ください。
PytorchでBERTの日本語学習済みモデルを利用する - 文章埋め込み編
- 日本語の学習済みBERTモデル: BERT日本語Pretrainedモデル - KUROHASHI-KAWAHARA LAB
- BERT実装: pytorch-pretrained-BERT
- 形態素解析器: JUMAN++ - KUROHASHI-KAWAHARA LAB
京都大学の黒橋・河原研究室が公開している「BERT日本語Pretrainedモデル」を利用します。下記ウェブページからモデルファイルをダウンロードして解凍してください。
BERT日本語Pretrainedモデル - KUROHASHI-KAWAHARA LAB
Juman++をインストールします。インストール方法については、下記の公式レポジトリを参照ください。
ku-nlp/jumanpp: Juman++ (a Morphological Analyzer Toolkit)
なお、macOSならばHomebrewを使って下記のように簡単にインストールできます。
$ brew install jumanpp
pytorch-pretrained-bert
およびpyknp
をインストールします。
$ pip install pytorch-pretrained-bert
$ pip install pyknp
なお、ここではPytorchをBERT実装に利用するので、Pytorchはインストールされているものとします。
本レポジトリのbert_juman.py
からBertWithJumanModel
クラスをインポートします。クラスの引数には、ダウンロードした日本語の学習済みBERTモデルのディレクトリを指定します。必要なファイルはpytorch_model.bin
とvocab.txt
のみです。
In []: from bert_juman import BertWithJumanModel
In []: bert = BertWithJumanModel("/path/to/Japanese_L-12_H-768_A-12_E-30_BPE")
In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([ 2.22642735e-01, -2.40221739e-01, 1.09303640e-02, -1.02307117e+00,
1.78834641e+00, -2.73566216e-01, -1.57942638e-01, -7.98571169e-01,
-2.77438164e-02, -8.05811465e-01, 3.46736580e-01, -7.20409870e-01,
1.03382647e-01, -5.33944130e-01, -3.25344890e-01, -1.02880754e-01,
2.26500735e-01, -8.97880018e-01, 2.52314955e-01, -7.09809303e-01,
[...]
またget_sentence_embedding()
の引数には、文章ベクトルを計算するのに利用するBERTの隠れ層の位置pooling_layer
と、プーリングの方法pooling_strategy
が指定できます。pooling_layer
は-1
で最終層、-2
で最終層の手前の層となります。また、pooling_strategy
には
REDUCE_MEAN
: 要素ごとにaverage-poolingREDUCE_MAX
: 要素ごとにmax-poolingREDUCE_MEAN_MAX
:REDUCE_MEAN
とREDUCE_MAX
を結合したものCLS_TOKEN
: [CLS]トークンのベクトルをそのまま利用
が選択できます。
In []: bert.get_sentence_embedding("吾輩は猫である。",
...: pooling_layer=-1,
...: pooling_strategy="REDUCE_MAX")
...:
Out[]:
array([ 1.2089624 , 0.6267309 , 0.7243419 , -0.12712255, 1.8050476 ,
0.43929055, 0.605848 , 0.5058241 , 0.8335829 , -0.26000524,
[...]
これらのパラメータはhanxiao/bert-as-serviceを参考にしています。
In []: bert = BertWithJumanModel("../Japanese_L-12_H-768_A-12_E-30_BPE", use_cuda=True)
In []: bert.get_sentence_embedding("吾輩は猫である。")
Out[]:
array([-4.25627649e-01, -3.42006773e-01, -7.15175271e-02, -1.09820020e+00,
1.08186746e+00, -2.35576674e-01, -1.89862609e-01, -5.50959229e-01,