勾配降下法


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day55


経緯についてはこちらをご参照ください。



■本日の進捗

  • 勾配降下法を理解


■はじめに

今回も「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装(オライリー・ジャパン)」で、深層学習を学んでいきます。

今回は、ついに損失関数を最小化して自動で学習させる方法を学びます。

■勾配降下法

勾配降下法(gradient descent method)とは、損失関数の値を徐々に減らしていくことでニューラルネットワークのパラメータ(重みやバイアス)を最適化するアルゴリズムです。

これまでは損失関数の値を算出するところまで実装してきましたが、これを上手く最小化することで初めて誤差を減らして良い学習モデルを構築することができます。

$$ x = x – \eta \frac{ \partial f }{ \partial x } $$

ここで、ηは学習率(learning rate)と呼ばれ、パラメータの更新の大きさを制御する定数です。この値は人間が事前に決めることになります。偏微分項は損失関数fの勾配で、現在の値から最も大きく増加する方向を示します。この逆方向を取ることで、損失を減らせるというものです。

交差エントロピー誤差を勾配降下法で最小化する手法でニューラルネットワークに実装してみます。

import sys
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("./")
from work.mnist import load_mnist

def relu(x):
    return np.maximum(0, x)

def softmax(z):
    c = np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z - c)
    sum_exp_z = np.sum(exp_z, axis=1, keepdims=True)
    y = exp_z / sum_exp_z
    return y

def a(x, W, b):
    return np.dot(x, W) + b

def init_network():
    network = {}
    network['W1'] = np.random.randn(784, 50) * 0.1
    network['b1'] = np.zeros(50)
    network['W2'] = np.random.randn(50, 100) * 0.1
    network['b2'] = np.zeros(100)
    network['W3'] = np.random.randn(100, 10) * 0.1
    network['b3'] = np.zeros(10)
    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = a(x, W1, b1)
    z1 = relu(a1)
    a2 = a(z1, W2, b2)
    z2 = relu(a2)
    a3 = a(z2, W3, b3)
    y = softmax(a3)

    return y

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta)) / y.shape[0]

def numerical_gradient(f, x):
    h = 1e-4
    grad = np.zeros_like(x)

    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        tmp_val = x[idx]
        
        x[idx] = tmp_val + h
        fxh1 = f(x)  # f(x + h)
        
        x[idx] = tmp_val - h
        fxh2 = f(x)  # f(x - h)
        
        grad[idx] = (fxh1 - fxh2) / (2 * h)
        x[idx] = tmp_val
        it.iternext()
    
    return grad

def numerical_gradient_network(network, x, t):
    grads = {}
    
    def loss_W(W): 
        y = forward(network, x)
        return cross_entropy_error(y, t)
    
    grads['W1'] = numerical_gradient(loss_W, network['W1'])
    grads['b1'] = numerical_gradient(loss_W, network['b1'])
    grads['W2'] = numerical_gradient(loss_W, network['W2'])
    grads['b2'] = numerical_gradient(loss_W, network['b2'])
    grads['W3'] = numerical_gradient(loss_W, network['W3'])
    grads['b3'] = numerical_gradient(loss_W, network['b3'])
    
    return grads

(X_train, y_train), (X_test, y_test) = load_mnist(normalize=True, one_hot_label=True)

print("X_train.shape :{}".format(X_train.shape))
train_size = X_train.shape[0]

batch_size = 10
np.random.seed(8)
batch_mask = np.random.choice(train_size, batch_size)
X_batch = X_train[batch_mask]
y_batch = y_train[batch_mask]

lr = 0.01
step_num = 100
network = init_network()

for i in range(step_num):
    grads = numerical_gradient_network(network, X_batch, y_batch)
    
    for key in ('W1', 'b1', 'W2', 'b2', 'W3', 'b3'):
        network[key] -= lr * grads[key]

    y = forward(network, X_batch)
    loss = cross_entropy_error(y, y_batch)
    print(f"Step {i+1}, Loss: {loss}")

MNISTデータセットを用いたニューラルネットワークの損失関数を算出し、勾配降下法で100ステップ学習させ、損失関数の変化を見てみました。

ステップごとに徐々に交差エントロピー誤差の値が下がっていて、ニューラルネットワークが上手く学習できている様子が確認できました。

■おわりに

遂に機械学習ライブラリを用いずにニューラルネットワークがデータを学習するところまで実装することができました。

ちなみに前回学んだミニバッチ化の要素を少し入れていますが、バッチ(全データ)で学習させてみたところ、爆裂に重かった(Ryzen9をもってしても1ステップも動かない…)のでこのように書いてみました。並列化やGPU化ができれば恐らくそれほど問題ではなかったのかもしれませんが、その辺もやっぱり既存ライブラリのメリットだと実感できます。しかし、自分で実装してみるというのは何にも代えがたい達成感と理解が深まる実感をずっしりと感じます。この記事に巡り合ってくださった方がもしいらっしゃったら是非参考文献(と本稿も入れていただければ光栄です)を参考に書いてみてください。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
  3. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  4. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です