ミニバッチ学習


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day54


経緯についてはこちらをご参照ください。



■本日の進捗

  • ミニバッチ学習を理解


■はじめに

今回も「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装(オライリー・ジャパン)」で、深層学習を学んでいきます。

今回は、前回導入した損失関数の計算を効率的に実施する手法を組み込んでいきます。

■ミニバッチ学習

ミニバッチ学習とは、データセット全体を一度の学習で全て用いずに小さく分割したデータ毎に学習させる手法です。反対に言えば、学習(重みやバイアスの更新)をあるデータ単位で行うのではなく、小さなひとまとめのデータ(ミニバッチ)毎に行うということに他なりません。

前回導入した交差エントロピー誤差を振り返ってみます。

$$ E =\ – \displaystyle \sum_{k} t_k \log y_k $$

この誤差は、k(例えばk=1)番目の値での予測値と正解の値の誤差を表しているのでした。データが複数(例えばN個)ある時は下記のように記述できます。

$$ E(t, y) =\ – \displaystyle \sum_{k} t_k \log y_k $$

$$ L = \frac{1}{N} \displaystyle \sum_{n} E(t_n, y_n) = \ – \frac{1}{N} \displaystyle \sum_{n} \sum_{k} t_{nk} \log y_{nk} $$

これはロジスティック損失とも呼ばれ、データセットが大規模な場合には損失関数の計算に膨大な計算コストがかかるのみならず、メモリのオーバーフローや収束速度の低下が懸念されます。

データセット全体を小さいセットに分割することで、これを解決しようという訳です。

一般的にはミニバッチを分割数分だけパラメータを更新しながら繰り返し、すべての教師データを使い終わる回数をエポックという単位で表現します。場合によってはこれを数エポック繰り返すこともあります。

今回は一組のミニバッチをランダムで選定し、一度だけ学習する場合を考えます。単にメモリ使用量や計算コストの削減を目的とします。

その前に、大規模データセットを導入します。

■MNISTデータセット

MNISTデータセットとは、手書きの0~9までの数字の画像を集めたもので、下記のページからダウンロードできます。

https://yann.lecun.com/exdb/mnist/

ただし、参考図書のソースがGitHubに公開されていて、ダウンロードのみならずone-hot化までクラス化されているので通常はこちらの方が便利です。(dataset/mnist.pyを参照)

https://github.com/oreilly-japan/deep-learning-from-scratch

今回はありがたくこちらを使用させていただきます。

import sys
import numpy as np
sys.path.append("./")
from work.mnist import load_mnist

(X_train, y_train), (X_test, y_test) = load_mnist(normalize=True, one_hot_label=True)

print("X_train.shape :{}".format(X_train.shape))
train_size = X_train.shape[0]

Pathはご自身の環境に応じて変更してください。mnist.pyがある場所を指定すればいいです。

■ミニバッチの実装

先程の結果から、教師データの数は60000でした。これを全て用いて損失関数を計算するのは大変そうですが、とりあえずやってみます。

import sys
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("./")
from work.mnist import load_mnist

def relu(x):
    return np.maximum(0, x)

def softmax(z):
    c = np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z - c)
    sum_exp_z = np.sum(exp_z, axis=1, keepdims=True)
    y = exp_z / sum_exp_z
    return y

def a(x, W, b):
    a = np.dot(x, W) + b
    return a

def init_network():
    network = {}
    network['W1'] = np.random.randn(784, 50) * 0.1
    network['b1'] = np.zeros(50)
    network['W2'] = np.random.randn(50, 100) * 0.1
    network['b2'] = np.zeros(100)
    network['W3'] = np.random.randn(100, 10) * 0.1
    network['b3'] = np.zeros(10)
    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = a(x, W1, b1)
    z1 = relu(a1)
    a2 = a(z1, W2, b2)
    z2 = relu(a2)
    a3 = a(z2, W3, b3)
    y = softmax(a3)

    return y

(X_train, y_train), (X_test, y_test) = load_mnist(normalize=True, one_hot_label=True)

print("X_train.shape :{}".format(X_train.shape))
train_size = X_train.shape[0]

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta)) / y.shape[0]

network = init_network()
y = forward(network, X_train)

loss = cross_entropy_error(y, y_train)
print("cross entropy error = {:.3f}".format(loss))

思ったより全然軽かったです。GridSearchCVが懐かしいです。

それはともかく折角なので(本質的には違いますが)ミニバッチ化してみます。

import sys
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("./")
from work.mnist import load_mnist

def relu(x):
    return np.maximum(0, x)

def softmax(z):
    c = np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z - c)
    sum_exp_z = np.sum(exp_z, axis=1, keepdims=True)
    y = exp_z / sum_exp_z
    return y

def a(x, W, b):
    a = np.dot(x, W) + b
    return a

def init_network():
    network = {}
    network['W1'] = np.random.randn(784, 50) * 0.1
    network['b1'] = np.zeros(50)
    network['W2'] = np.random.randn(50, 100) * 0.1
    network['b2'] = np.zeros(100)
    network['W3'] = np.random.randn(100, 10) * 0.1
    network['b3'] = np.zeros(10)
    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = a(x, W1, b1)
    z1 = relu(a1)
    a2 = a(z1, W2, b2)
    z2 = relu(a2)
    a3 = a(z2, W3, b3)
    y = softmax(a3)

    return y

(X_train, y_train), (X_test, y_test) = load_mnist(normalize=True, one_hot_label=True)

print("X_train.shape :{}".format(X_train.shape))
train_size = X_train.shape[0]

batch_size = 10
np.random.seed(8)
batch_mask = np.random.choice(train_size, batch_size)
X_batch = X_train[batch_mask]
y_batch = y_train[batch_mask]

print("batch_mask :{}".format(batch_mask))

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta)) / y.shape[0]

network = init_network()
y = forward(network, X_batch)

loss = cross_entropy_error(y, y_batch)
print("cross entropy error = {:.3f}".format(loss))

60000個のデータから10個をミニバッチとして取り出してニューラルネットワークにかけました。数値は若干変動しています。

ミニバッチの数を増やしてみます。一通り記載するわけにもいかないので、今回は一気に10000個まで上げました。

import sys
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("./")
from work.mnist import load_mnist

def relu(x):
    return np.maximum(0, x)

def softmax(z):
    c = np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(z - c)
    sum_exp_z = np.sum(exp_z, axis=1, keepdims=True)
    y = exp_z / sum_exp_z
    return y

def a(x, W, b):
    a = np.dot(x, W) + b
    return a

def init_network():
    network = {}
    network['W1'] = np.random.randn(784, 50) * 0.1
    network['b1'] = np.zeros(50)
    network['W2'] = np.random.randn(50, 100) * 0.1
    network['b2'] = np.zeros(100)
    network['W3'] = np.random.randn(100, 10) * 0.1
    network['b3'] = np.zeros(10)
    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = a(x, W1, b1)
    z1 = relu(a1)
    a2 = a(z1, W2, b2)
    z2 = relu(a2)
    a3 = a(z2, W3, b3)
    y = softmax(a3)

    return y

(X_train, y_train), (X_test, y_test) = load_mnist(normalize=True, one_hot_label=True)

print("X_train.shape :{}".format(X_train.shape))
train_size = X_train.shape[0]

batch_size = 10000
np.random.seed(8)
batch_mask = np.random.choice(train_size, batch_size)
X_batch = X_train[batch_mask]
y_batch = y_train[batch_mask]

print("batch_mask :{}".format(batch_mask))

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta)) / y.shape[0]

network = init_network()
y = forward(network, X_batch)

loss = cross_entropy_error(y, y_batch)
print("cross entropy error = {:.3f}".format(loss))



■おわりに

今回は損失関数の算出にミニバッチを導入しました。大規模なデータセットに対して効果を発揮します。また、1エポックにつき複数回(ミニバッチ数分)のパラメータ更新が行われることになるので、収束が早く場合によっては精度向上が期待できます。

これは人間に例えるならば、英単語帳を一度だけ(もちろん振り返らずに)頭から最後まで目を通してどのくらい覚えられるかということです。それよりも単語帳を10分割して1セグメントずつ覚える方が効率的に覚えられると聞いたことがあります。また、1冊通しを何度も繰り返すことでも記憶の定着がしやすいと聞いたこともあります。前者は正しくミニバッチ学習で、後者は損失関数を用いたパラメータの更新頻度に例えられるのではないでしょうか。学習効率を考えるならば、機械学習も人間も、大規模なデータを複数に分割して何度も記憶を修正しながら学習していくのが良いということです。え、ちょっと違う?ニューラルネットワークの気持ちは、まあ、そんな感じですよきっと。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
  3. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  4. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です