AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day85
経緯についてはこちらをご参照ください。
■本日の進捗
- 類推問題を理解
■はじめに
今回も「ゼロから作るDeep Learning② 自然言語処理編(オライリー・ジャパン)」から学んでいきます。
今回は、せっかく作った推論ベースモデルであるCBoWを用いて類推問題と呼ばれる推論タスク行ってみます。
■類推問題
類推問題(Analogy)とは、単語の関係性を用いた推論タスクで、「AとBならばCとD」という言語的な関係を機械学習モデルが理解できるかを評価するタスクです。
意味的関係は「王と女王ならば男と?」といった意味的に関係がある単語を類推するタスクです。人間であれば答えが「女」であることはすぐに分かります。
文法的関係は「走ると走ったならば行くと?」といった文法的に関係がある単語を類推するタスクです。これも過去形で並んでいるので答えは「行った」であろうことは簡単に類推できます。
しかしニューラルネットワークがこのタスクをこなせるかはモデルが言語を意味的に、あるいは文法的に理解しているかを評価するために非常に良い指標になります。
直観的には、これまで行ってきたようにコーパスの各単語を分散表現に落とし込み、コサイン類似度でスコアを割り出せば類推できそうです。
■Analogyクラス
類推を行うためのクラスを実装していきます。
このクラスの引数には、類推の元となる単語(a, b, c)、単語(word_to_id)とインデックス(id_to_word)、分散表現になった(word_matrix)、類似単語数(top)、正解とされる単語(answer)を取ります。
まずは単語の存在を確認します。類推の元となる単語a, b, cがword_to_idに存在するかを確認して、存在しなければエラー処理を行います。
def analogy(a, b, c, word_to_id, id_to_word, word_matrix, top=5, answer=None):
for word in (a, b, c):
if word not in word_to_id:
print('%s is not found' % word)
return
次に単語a, b, cに対してそれぞれ対応するベクトルを取得してから、Analogyベクトルを計算します。
$$ \mathrm{query\_vec} = \mathrm{\boldsymbol{b}}\ -\ \mathrm{\boldsymbol{a}} \ +\ \mathrm{\boldsymbol{c}} $$
このAnalogyベクトルを正規化(このクラスは別途定義)して類似度を計算する際のスケール影響を取り除いてからコサイン類似度を計算します。
print('\n[analogy] ' + a + ':' + b + ' = ' + c + ':?')
a_vec, b_vec, c_vec = word_matrix[word_to_id[a]], word_matrix[word_to_id[b]], word_matrix[word_to_id[c]]
query_vec = b_vec - a_vec + c_vec
query_vec = normalize(query_vec)
similarity = np.dot(word_matrix, query_vec)
正解単語オプションが設定されている時のみ、その類似度を出力します。
if answer is not None:
print("==>" + answer + ":" + str(np.dot(word_matrix[word_to_id[answer]], query_vec)))
最後に最も類似する単語を取得して、最大類似単語数に達したら終了します。
ループは類似度の降順でソートしてその時のインデックスで回します。
(-1 * similarity).argsort()
類似度がNaNの場合、または入力単語(a, b, c)と同じ単語であればスキップします。
count = 0
for i in (-1 * similarity).argsort():
if np.isnan(similarity[i]):
continue
if id_to_word[i] in (a, b, c):
continue
print(' {0}: {1}'.format(id_to_word[i], similarity[i]))
count += 1
if count >= top:
return
■Analogyの簡易実装
先ほどのAnalogyクラスを用いて学習済みの重みを用いた類推問題を処理するコードをしてみます。
import sys
import os
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
import pickle
import collections
def normalize(x):
if x.ndim == 2:
s = np.sqrt((x * x).sum(1))
x /= s.reshape((s.shape[0], 1))
elif x.ndim == 1:
s = np.sqrt((x * x).sum())
x /= s
return x
def analogy(a, b, c, word_to_id, id_to_word, word_matrix, top=5, answer=None):
for word in (a, b, c):
if word not in word_to_id:
print('%s is not found' % word)
return
print('\n[analogy] ' + a + ':' + b + ' = ' + c + ':?')
a_vec, b_vec, c_vec = word_matrix[word_to_id[a]], word_matrix[word_to_id[b]], word_matrix[word_to_id[c]]
query_vec = b_vec - a_vec + c_vec
query_vec = normalize(query_vec)
similarity = np.dot(word_matrix, query_vec)
if answer is not None:
print("==>" + answer + ":" + str(np.dot(word_matrix[word_to_id[answer]], query_vec)))
count = 0
for i in (-1 * similarity).argsort():
if np.isnan(similarity[i]):
continue
if id_to_word[i] in (a, b, c):
continue
print(' {0}: {1}'.format(id_to_word[i], similarity[i]))
count += 1
if count >= top:
return
pkl_file = 'cbow_params.pkl'
with open(pkl_file, 'rb') as f:
params = pickle.load(f)
word_vecs = params['word_vecs']
word_to_id = params['word_to_id']
id_to_word = params['id_to_word']
analogy('king', 'queen', 'man', word_to_id, id_to_word, word_vecs)
重みは参考文献で公開しているGitHubからダウンロードできます。
https://github.com/oreilly-japan/deep-learning-from-scratch-2
■おわりに
今回は類推問題の処理と学習済みの重みデータを用いた類推処理を実装してみました。実際にPTBデータセットを学習してからその学習結果を用いて類推を行うモデルは明日実装する予定です。(上記のコードで出力される類推結果も明日に残しておきます)
■参考文献
- Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
- 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
- 斎藤 康毅. ゼロから作るDeep Learning② 自然言語処理編. オライリー・ジャパン. 2018. 432p.
- ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
- API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html
- PyTorch documentation. pytorch.org. https://pytorch.org/docs/stable/index.html
- Keiron O’Shea, Ryan Nash. An Introduction to Convolutional Neural Networks. https://ar5iv.labs.arxiv.org/html/1511.08458
- API Reference. scipy.org. 2024. https://docs.scipy.org/doc/scipy/reference/index.html