多クラス分類の評価基準


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day38


経緯についてはこちらをご参照ください。



■本日の進捗

●多クラスの評価基準を理解

■はじめに

引き続き「Pythonではじめる機械学習(オライリー・ジャパン)」で学んでいきます。

前回までは2クラス分類における評価基準を混同行列から適合率、再現率、f-値といった指標を導入し、ROCカーブやAUCといった評価手法を学んできました。今回はこれを多クラス分類タスクの場合に発展させていきます。

■多クラス分類の評価基準

多クラス分類タスクの場合も、2クラス分類タスクでの評価基準を拡張することで適用できます。

これまでの評価手法に対して一対他(one-vs-rest:OvR)アプローチやマルチクラスROC(one-vs-one:OvO)アプローチを適用することでAUCの算出が可能です。

OvRアプローチではクラスの数だけROCカーブが描けるので、それぞれに対してAUCを計算し、平均値を見ることで評価できます。

scikit-learnでは、roc_auc_scoreのmulti_class引数にovrを渡すことで簡単に実装可能です。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_predict, cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_auc_score

digits = load_digits()
X = digits.data
y = digits.target

model = OneVsRestClassifier(LogisticRegression(random_state=8, max_iter=1000))

y_prob = cross_val_predict(model, X, y, cv=5, method='predict_proba')

n_classes = len(np.unique(y))
roc_auc_per_class = {}

for i in range(n_classes):
    roc_auc_per_class[i] = roc_auc_score((y == i).astype(int), y_prob[:, i])

for i, auc in roc_auc_per_class.items():
    print(f'Class {i} AUC = {auc:.2f}')

roc_auc_overall = roc_auc_score(y, y_prob, multi_class='ovr')
print('Overall AUC (OvR) = {:.2f}'.format(roc_auc_overall))

accuracy = cross_val_score(model, X, y, cv=5, scoring='accuracy')
print('Accuracy = {:.2f}'.format(accuracy.mean()))

手書きの0から9までの数字で構成された多クラス分類タスクであるdigitsデータセットにロジスティック回帰で学習させて、OvRアプローチでAUCを算出してみました。

なお、今回のモデル構築は出力自体はちゃんと出してくれるのですが、なぜかOneVsRestClassifierを使うよう注意されたのでそのようにしています。

■おわりに

今回までクラス分類タスクにおける機械学習モデルの評価基準を考えてきましたが、回帰タスクについてはまだ触れていません。しかし、データのリーク等を除いて回帰タスクではR2スコア等の評価手法でデータを上手く説明できることが多いです。

R2スコアも過剰適合に対しては少々評価が難しいところがあって、平均絶対誤差(MAE)、平均二乗誤差(MSE)などを検討するのも良いでしょう。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  3. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です