損失関数


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day53


経緯についてはこちらをご参照ください。



■本日の進捗

  • 平均二乗誤差を理解
  • 交差エントロピー誤差を理解


■はじめに

今回も「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装(オライリー・ジャパン)」で、深層学習を学んでいきます。

今回はニューラルネットワークに、学習において重要な損失関数を導入していきます。

■損失関数

これまで構築してきたニューラルネットワークを振り返ってみましょう。前回は最後にこのニューラルネットワークの問題点を提起していました。

import numpy as np
import matplotlib.pyplot as plt

def relu(x):
    return np.maximum(0, x)

def softmax(z):
    c = np.max(z)
    exp_z = np.exp(z - c)
    sum_exp_z = np.sum(exp_z)
    y = exp_z / sum_exp_z
    return y

def a(x, W, b):
    a = np.dot(x, W) + b
    return a

def init_network():
    network = {}
    network['W1'] = np.array([[0.1, 0.3, 0.5], [0.2, 0.4, 0.6]])
    network['b1'] = np.array([0.1, 0.2, 0.3])
    network['W2'] = np.array([[0.1, 0.4], [0.2, 0.5], [0.3, 0.6]])
    network['b2'] = np.array([0.1, 0.2])
    network['W3'] = np.array([[0.1, 0.3], [0.2, 0.4]])
    network['b3'] = np.array([0.1, 0.2])

    return network

def forward(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    # a1 = np.dot(x, W1) + b1
    a1 = a(x, W1, b1)
    z1 = relu(a1)
    # a2 = np.dot(z1, W2) + b2
    a2 = a(z1, W2, b2)
    z2 =relu(a2)
    # a3 = np.dot(z2, W3) + b3
    a3 = a(z2, W3, b3)
    y = softmax(a3)

    return y

network = init_network()
x = np.array([1.0, 0.5])
y = forward(network,x)
print("y = {}".format(y))

print("sum(y) = {}".format(np.sum(y)))

重要なパラメータ(重みやバイアス)を人間が与えています。つまり、このモデルの深刻な問題点はモデル自身がデータを学習していないことにあります。これではいくら複雑な表現ができて回帰やクラス分類に適した出力をしていても実際のデータに対応することができません(人間が頑張ってデータを学習すればその限りではありませんが、もはや機械学習を用いる意味がなくなってしまいます)。

損失関数(loss function)は、モデルの予測結果と実際の正解のデータとの誤差を測定することができ、モデルが教師データを学習する過程でこの誤差を減らすように重みを最適化することで、パラメータを自動で決定することができます。

■二乗和誤差

二乗和誤差(SSE:sum of squared error)とは、モデルの予測値と実際の値を誤差の二乗の合計を取ったもので、平均二乗誤差が平均を取るのに対して、二乗和誤差はこれをしません。

$$ E = \frac{1}{2} \displaystyle \sum_{k} (y_k – t_k)^2 $$

二乗を微分すると必ず2が出てくるので、簡単にするために1/2をかけます(必須ではありません)。二乗することで、誤差が正の値に変換され、より増大させて評価することができます。

実際に最適化するには逆伝播などの手法を導入しないといけないため、ここではその値を見るだけに留めておきます。

import numpy as np

def sum_squared_error(y, t):
    return 0.5 * np.sum((y - t) ** 2)

t = np.array([2.0])
y = np.array([1.5])

sse = sum_squared_error(y, t)
print("sum of squared error = {:.3f}".format(sse))

予測値が1.5、正解の値が2.0と仮置きして、一度だけ二乗和誤差を計算してその値を出力しています。

誤差をより近づけるような値を与えたら二乗和誤差の値はどう変わるでしょうか。

import numpy as np

def sum_squared_error(y, t):
    return 0.5 * np.sum((y - t) ** 2)

t = np.array([2.0])
y = np.array([1.8])

sse = sum_squared_error(y, t)
print("sum of squared error = {:.3f}".format(sse))

二乗和誤差による乖離の指標は小さくなりました。(実態は力業ですが)より上手く予測できたという評価です。

二乗和誤差は一組の誤差に対する評価には適していますが、クラス分類タスクのような確率分布全体で評価するような場合にはあまり向きません。そのため回帰タスクで用いられることが多いです。(もちろんクラス分類に用いること自体は可能で、参考図書ではクラス分類に対して使っているので注意してください。)

■交差エントロピー誤差

交差エントロピー誤差(cross entropy error)は、主にクラス分類タスクに用いられる損失関数で、確率分布間の距離を測ることで予測確率と実際のラベルとの誤差を指標化してくれます。

$$ E =\ – \displaystyle \sum_{k} t_k \log y_k $$

こちらもその挙動を見てみることにします。

import numpy as np

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))

t = np.array([1, 0])
y = np.array([0.9, 0.1])

loss = cross_entropy_error(y, t)
print("cross entropy error = {:.3f}".format(loss))

これは上手く予測手出来ている場合ですが、もっと確信度が低くなるような結果ならどうなるでしょうか。

import numpy as np

def cross_entropy_error(y, t):
    delta = 1e-7
    return -np.sum(t * np.log(y + delta))

t = np.array([1, 0])
y = np.array([0.6, 0.4])

loss = cross_entropy_error(y, t)
print("cross entropy error = {:.3f}".format(loss))

交差エントロピー誤差の値が大きくなっていることが分かります。

■おわりに

今回は損失関数を導入し、回帰とクラス分類で使われる二乗和誤差と交差エントロピー誤差の挙動を確認してみました。

前述の通り、これをニューラルネットワークの学習に用いるには更なる手法が必要になります。これがないと損失関数の効果を実感することはできませんが、ニューラルネットワークにとってはこの指標が非常に重要で、今自分がどのくらい学習できているのか、学習(重み)をどちらの方向にずらせばモデルの精度が良くなるか等の重要な指標になります。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
  3. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  4. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です