AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day91
経緯についてはこちらをご参照ください。
■本日の進捗
- LSTMの理論を理解
■はじめに
今回も「ゼロから作るDeep Learning② 自然言語処理編(オライリー・ジャパン)」から学んでいきます。
今回は、再帰型ニューラルネットワークを拡張したLSTMについて学びたいと思います。
■LSTM
LSTM(Long Short-Term Memory)とは、再帰型ニューラルネットワークで起こりやすいとされる勾配消失問題を抑制し、長い時系列情報を記憶を保ったまま学習できる拡張的な手法で、誤差逆伝播で勾配が指数関数的に小さくなり過去の情報が欠損していくことを防ぎます。
■再帰型ニューラルネットワークの問題
再帰型ニューラルネットワークでは過去の情報を保持できるようになりましたが、長文を学習するには常に関連がある単語までをその間も含めて記憶しておく必要があり、膨大なメモリと計算コストを必要としてしまいます。(僕らが難しい本を読む時に数行前の文を忘れてしまい現在の文章の意味が分からなくなってしまうのに似ています。この時は任意の位置まで返り読みすることになりますが、これは計算コストにあたるでしょうか。機械的にこれを行うには忘れないように常に数行前まで振り返る必要がありそうですし、もし必要な単語があと1単語過去のものであればお手上げです。では人間のように不要なものを忘れていったら?これを行うのがLSTMです。)
また、このような長文への対応は単にコストの問題だけでなく、勾配爆発や勾配損失といったもっと致命的な問題が潜んでいます。
再帰型ニューラルネットワークでは下記の式で示されるような隠れ状態(ht)を計算して順伝播を行うのでした。
$$ h_t = \sigma ( \mathrm{\boldsymbol{h}}_{t-1} \mathrm{\boldsymbol{W}}_h + \mathrm{\boldsymbol{x}}_t \mathrm{\boldsymbol{W}}_x + \mathrm{\boldsymbol{b}} ) $$
ここで、活性化関数(σ)はアルゴリズムに応じてtanhやReLUなどが選ばれます。
誤差逆伝播においてはRNN層内で損失の勾配が累積されていくことになるので、T番目のRNNからの逆伝播は、
$$ \frac{\partial L}{\partial \mathrm{\boldsymbol{W}}_h} = \displaystyle \sum_{k=1}^T \frac{\partial L}{\partial \mathrm{\boldsymbol{h}}_k} \cdot \frac{\partial \mathrm{\boldsymbol{h}}_k}{\partial \mathrm{\boldsymbol{h}}_{k-1}} \cdot \frac{\partial \mathrm{\boldsymbol{h}}_{k-1}}{\partial \mathrm{\boldsymbol{W}}_h} $$
ここで、右辺第2項は、
$$ \frac{\partial \mathrm{\boldsymbol{h}}_k}{\partial \mathrm{\boldsymbol{h}}_{k-1}} =
\mathrm{\boldsymbol{W}}_h \cdot
\sigma’ ( \mathrm{\boldsymbol{h}}_{k-1} \mathrm{\boldsymbol{W}}_h + \mathrm{\boldsymbol{x}}_k \mathrm{\boldsymbol{W}}_x )
$$
つまり誤差逆伝播の出力に上式の時間方向への累乗が含まれることになるので、
$$
\frac{\partial L}{\partial \mathrm{\boldsymbol{h}}_t } =
\displaystyle \prod_{k=t}^T
\mathrm{\boldsymbol{W}}_h \cdot
\sigma’ (z_k)
$$
ここで重みが累乗されていくので、重みが1より大きければ指数関数的に増大していき、重みが1より小さければ指数関数的に減少していくので、RNN層で勾配爆発や勾配消失が起こりやすいというわけです。
■セル状態
LSTMでは効率的な記憶の保持のために、セル状態(cell state)という長期的な情報の保持を担うセルを持ちます。このセル状態が通常の隠れ状態とは別で情報を保持することで、情報が劣化(勾配爆発や勾配消失)することなく長期間に渡って安定した伝播が可能になります。
セル状態を更新して、重要でない情報を忘れたり、新しく学んだ情報を追加したり、次の隠れ状態へどの情報を渡すのかをコントロールするには、ゲートと呼ばれる機構を実装する必要があります。
■ゲート
LSTMは考え方はほとんどRNNと同じですが、セル状態とその制御のためのゲートを持っていることが大きな特徴なので、ゲート付き再帰型ニューラルネットワークとも呼ばれます。
ゲートはRNN層内での情報の流れを制御する役割を担う数理機構で、ゲート内の重みを用いて特定の情報を通すか遮断するかを連続的に、そして動的に判断する仕組みを持ちます。もちろんここでの重みもモデルが自動的に情報の重要性を理解して学習していきます。
新しく記憶する情報は、tanhを用いて次のように定義されます。
$$ g = \tanh (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(g)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(g)} +
\mathrm{\boldsymbol{b}}^{(g)}
)$$
LSTMは3つのゲートを持ちます。
●忘却ゲート(Forget Gate)
セル状態の保持と忘却を制御します。
$$ f = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(f)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(f)} +
\mathrm{\boldsymbol{b}}^{(f)}
)$$
●入力ゲート(OutPut Gate)
入力情報からセル状態に追加する情報量を制御します。つまり先ほどの新しい情報(g)の追加を制御するためのものです。
$$ i = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(i)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(i)} +
\mathrm{\boldsymbol{b}}^{(i)}
)$$
●出力ゲート(Output Gate)
セル状態から次の隠れ状態へ反映させる出力を制御します。
$$ o = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(o)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(o)} +
\mathrm{\boldsymbol{b}}^{(o)}
)$$
●セル状態と隠れ状態
これらはシグモイド関数(σ)によって変換され、0~1の値を取ります。(これはゲートを何%開けるかどうかと考えられます。)
結果として、現在の時刻(t)におけるセル状態(ct)は、前の時刻(t-1)におけるセル状態(ct-1)から忘れるべき忘却ゲートをかけたものと、追加すべき入力ゲートの値を足し合わせたものになります。
$$
c_t = f \odot c_{t-1} + g \odot i
$$
出力される隠れ状態は、出力ゲートを通して出力されます。
$$ h_t =
o \odot \tanh (c_t)
$$
■おわりに
今回は再帰型ニューラルネットワークの問題点を理解し、それを解決するための手法であるLSTMを学びました。明日はこの内容を実装していこうと思います。
■参考文献
- Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
- 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
- 斎藤 康毅. ゼロから作るDeep Learning② 自然言語処理編. オライリー・ジャパン. 2018. 432p.
- ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
- API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html
- PyTorch documentation. pytorch.org. https://pytorch.org/docs/stable/index.html
- Keiron O’Shea, Ryan Nash. An Introduction to Convolutional Neural Networks. https://ar5iv.labs.arxiv.org/html/1511.08458
- API Reference. scipy.org. 2024. https://docs.scipy.org/doc/scipy/reference/index.html