LSTM その2


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day92


経緯についてはこちらをご参照ください。



■本日の進捗

  • LSTMの実装


■はじめに

今回も「ゼロから作るDeep Learning② 自然言語処理編(オライリー・ジャパン)」から学んでいきます。

今回は、前回学んだLSTMを実装していきたいと思います。

■LSTMクラス

LSTMの各ゲートを数式で振り返ります。

●新しい入力情報

$$ g = \tanh (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(g)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(g)} +
\mathrm{\boldsymbol{b}}^{(g)}
)$$

●忘却ゲート(Forget Gate)

$$ f = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(f)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(f)} +
\mathrm{\boldsymbol{b}}^{(f)}
)$$

●入力ゲート(OutPut Gate)

$$ i = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(i)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(i)} +
\mathrm{\boldsymbol{b}}^{(i)}
)$$

●出力ゲート(Output Gate)

$$ o = \sigma (
\mathrm{\boldsymbol{x}}_t
\mathrm{\boldsymbol{W}}_x^{(o)} +
\mathrm{\boldsymbol{h}}_{t-1}
\mathrm{\boldsymbol{W}}_h^{(o)} +
\mathrm{\boldsymbol{b}}^{(o)}
)$$

●セル状態

$$
c_t = f \odot c_{t-1} + g \odot i
$$

●隠れ状態

$$ h_t =
o \odot \tanh (c_t)
$$

これらを踏まえて、LSTMの処理を実装していきます。

まずは、入力(x)に関する重み(Wx)、隠れ状態(h)に関する重み(Wh)、バイアス(b)の3つを引数として受け取ります。

ここで、4つのゲート(f, i, o, g)でそれぞれ用いる重み行列を考えると、各ゲートで同様の演算を行うため、まとめて計算できるように上記の引数も基本的には4つ分の形状に対応できるようにします。

入力の次元数をD、隠れ状態の次元数をHとすると、Wx(D, 4H)、Wh(H, 4H)、b(4H, )となります。

class LSTM:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.cache = None

順伝播の場合は、引数としてバッチサイズ(N)×入力次元数(D)の形状を持つ入力データ(x)、バッチサイズ(N)×隠れ状態次元数の形状を持つ前の時刻での隠れ状態(h_prev)と前の時刻でのセル状態(c_prev)を受け取ります。

入力ゲートの計算(A)を重み行列の合計を用いて行ったら、各ゲートにその結果を分配します。この結果に対して各ゲートで活性化関数を適用したら、次の時刻へのセル状態と隠れ状態を計算し、self.cacheに保存したら終了です。

    def forward(self, x, h_prev, c_prev):
        Wx, Wh, b = self.params
        N, H = h_prev.shape

        A = np.dot(x, Wx) + np.dot(h_prev, Wh) + b

        f = A[:, :H]
        g = A[:, H:2*H]
        i = A[:, 2*H:3*H]
        o = A[:, 3*H:]

        f = sigmoid(f)
        g = np.tanh(g)
        i = sigmoid(i)
        o = sigmoid(o)

        c_next = f * c_prev + g * i
        h_next = o * np.tanh(c_next)

        self.cache = (x, h_prev, c_prev, i, f, g, o, c_next)
        return h_next, c_next

逆伝播の場合は、バッチサイズ(N)×隠れ状態次元数(H)の形状を持つ隠れ状態の誤差(dh_next)とセル状態の誤差(dc_next)を引数として受け取ったら、セル状態の勾配(ds)を用いて、各ゲートの勾配を算出します。順伝播では結果を各ゲートに分配していたので、これを元に戻すためにnp.hstackにより結合した行列(dA)を求めます。

最後に重みの勾配(dWh, dWx)とバイアスの勾配(db)を求めて、次の層へ渡す誤差(dx, dh_prev, dc_prev)を返したら終了です。

    def backward(self, dh_next, dc_next):
        Wx, Wh, b = self.params
        x, h_prev, c_prev, i, f, g, o, c_next = self.cache

        tanh_c_next = np.tanh(c_next)

        ds = dc_next + (dh_next * o) * (1 - tanh_c_next ** 2)

        dc_prev = ds * f

        di = ds * g
        df = ds * c_prev
        do = dh_next * tanh_c_next
        dg = ds * i

        di *= i * (1 - i)
        df *= f * (1 - f)
        do *= o * (1 - o)
        dg *= (1 - g ** 2)

        dA = np.hstack((df, dg, di, do))

        dWh = np.dot(h_prev.T, dA)
        dWx = np.dot(x.T, dA)
        db = dA.sum(axis=0)

        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db

        dx = np.dot(dA, Wx.T)
        dh_prev = np.dot(dA, Wh.T)

        return dx, dh_prev, dc_prev



■時系列LSTMクラス

RNNの場合と同様に、先ほどのLSTMクラスを時系列処理に対応させるためのクラスを用意しておきます。以前と同様に時間方向にシーケンスを処理する構造と、隠れ状態(ここではセル状態も追加されることに注意)を引き継げる構造にしておく必要があります。

まずは、入力の重み(Wx:D×4H)、隠れ状態の重み(Wh:H×4H)、バイアス(b:4H)そして隠れ状態を次のバッチへ引き継ぐかどうかを制御するブーリアン(stateful)を引数として受け取り、各種パラメータや現在の隠れ状態、セル状態等を初期化しておきます。

class TimeLSTM:
    def __init__(self, Wx, Wh, b, stateful=False):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx), np.zeros_like(Wh), np.zeros_like(b)]
        self.layers = None

        self.h, self.c = None, None
        self.dh = None
        self.stateful = stateful

順伝播の場合は、バッチサイズ(N)×タイムステップ数(T)×入力次元数(D)の形状を持つ時系列入力データ(xs)を引数として受け取り、出力となる各タイムステップの隠れ状態(hs)を初期化しておきます。

タイムステップ(t)ごとにLSTMクラスのインスタンスを生成し、順伝播を実行します。最後に各タイムステップをまとめた隠れ状態(hs)を返したら終了です。

    def forward(self, xs):
        Wx, Wh, b = self.params
        N, T, D = xs.shape
        H = Wh.shape[0]

        self.layers = []
        hs = np.empty((N, T, H), dtype='f')

        if not self.stateful or self.h is None:
            self.h = np.zeros((N, H), dtype='f')
        if not self.stateful or self.c is None:
            self.c = np.zeros((N, H), dtype='f')

        for t in range(T):
            layer = LSTM(*self.params)
            self.h, self.c = layer.forward(xs[:, t, :], self.h, self.c)
            hs[:, t, :] = self.h

            self.layers.append(layer)

        return hs

逆伝播の場合は、順伝播での出力(hs)を元に、勾配を逆方向に伝播させ、入力データの勾配(dxs)を返します。

まず、隠れ状態の勾配(dhs:N×T×H)を引数として受け取り、入力データの勾配(dxs)や次のタイムステップへの勾配である隠れ状態の勾配(dh)、セル状態の勾配(dc)を初期化します。

reverced関数を用いて、時間方向に逆順にループさせたら、各タイムステップでのLSTMクラスのインスタンスを呼び出し、それぞれで逆伝播を実行させます。

最後に各タイムステップでの勾配を累積してgradsに格納したら、入力データの勾配(dxs)を返り値として返して終了です。

    def backward(self, dhs):
        Wx, Wh, b = self.params
        N, T, H = dhs.shape
        D = Wx.shape[0]

        dxs = np.empty((N, T, D), dtype='f')
        dh, dc = 0, 0

        grads = [0, 0, 0]
        for t in reversed(range(T)):
            layer = self.layers[t]
            dx, dh, dc = layer.backward(dhs[:, t, :] + dh, dc)
            dxs[:, t, :] = dx
            for i, grad in enumerate(layer.grads):
                grads[i] += grad

        for i, grad in enumerate(grads):
            self.grads[i][...] = grad
        self.dh = dh
        return dxs

現在の隠れ状態とセル状態を外部から設定したりリセットするための関数を追加しておきます。

    def set_state(self, h, c=None):
        self.h, self.c = h, c

    def reset_state(self):
        self.h, self.c = None, None



■おわりに

今回は前回学んだLSTM手法をLSTMクラスとして実装し、時系列処理に対応する形で使えるようにしました。明日はこれらのクラスを用いて実際に学習させてみたいと思います。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
  3. 斎藤 康毅. ゼロから作るDeep Learning② 自然言語処理編. オライリー・ジャパン. 2018. 432p.
  4. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  5. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html
  6. PyTorch documentation. pytorch.org. https://pytorch.org/docs/stable/index.html
  7. Keiron O’Shea, Ryan Nash. An Introduction to Convolutional Neural Networks. https://ar5iv.labs.arxiv.org/html/1511.08458
  8. API Reference. scipy.org. 2024. https://docs.scipy.org/doc/scipy/reference/index.html


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です