畳み込みニューラルネットワーク その4


AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day72


経緯についてはこちらをご参照ください。



■本日の進捗

  • プーリング層を理解


■はじめに

今回も「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装(オライリー・ジャパン)」で、深層学習を学んでいきます。

今回は畳み込みニューラルネットワークの構成要素のもうひとつの層であるプーリング層を定義していきます。

■プーリング層

プーリング層(Pooling layer)とは、取り扱う特徴マップのチャンネル数は変えずに高さと幅のサイズをリダクションし、計算負荷を削減したり過学習抑制のために用いられます。プーリング層自体が何かを学習するわけではないので必ずしも必須ではありませんが、情報量の多いデータを扱うCNNにとって重要な役割を担う層です。

プーリング手法には最大値を取るMax Poolingや平均値を取るAverage Pooling、特徴マップ全体でひとつのスカラー値を算出するGlobal Pooling等があります。

前述した特徴マップのリダクションの他にも、入力データを上記手法で変換していくので、ストライドなどによるデータの空間的なずれに対してロバスト性が向上することも大きなメリットです。

■プーリング層の順伝播

Max Poolingでは、各ウインドウ内での最大値を抽出します。

$$ \boldsymbol{Y}_{i, j} = \mathrm{max} \boldsymbol{X}_{m, n} $$

処理自体はとても単純ですが、畳み込み層と同様に膨大な行列処理になる可能性が高いのでここでもim2col関数を用いた実装を考えます。

まずはプーリングのウインドウサイズ(pool_h, pool_w)から、プーリング後の出力サイズを求めておきます。これはこれまで何度も見てきたように次のように記述できます。

out_h = int(1 + (H - self.pool_h) / self.stride)
out_w = int(1 + (W - self.pool_w) / self.stride)

im2col関数で2次元配列に変換した結果をcolに格納します。ここで各プーリングによるウインドウの要素が1列に展開され、全てのウインドウをまとめて行列として表現します。

変換後は自動で算出した形状(引数の-1)に、ひとつのウインドウに含まれる要素数(pool_h × pool_w)で並び替えを行います。

col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
col = col.reshape(-1, self.pool_h*self.pool_w)

最後に各ウインドウ内の最大値を取得して、(データ数、出力高さ、出力幅、チャンネル数)の形状を、入力と同じ構成になるように(データ数、チャンネル数、出力高さ、出力幅)に入れ替えたら終わりです。

arg_max = np.argmax(col, axis=1)
out = np.max(col, axis=1)
out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

self.x = x
self.arg_max = arg_max



■プーリング層の逆伝播

逆伝播の場合もとてもシンプルですが、順伝播においては最大値に選ばれなかった特徴量は勾配として伝播していないので、逆伝播においてもこれを考慮しなくてはいけません。

$$ \frac{\partial L}{\partial \boldsymbol{X}_{m, n}} = \begin{eqnarray} \left\{ \begin{array}{l} \frac{\partial L}{\partial \boldsymbol{Y}_{i, j}} \ \ \ \ \ \mathrm{when} \ \boldsymbol{X}_{m, n} \ \mathrm{is} \ \mathrm{max} \\ 0 \end{array} \right. \end{eqnarray}$$

まずは前層の勾配を入力データ(順伝播での出力)と同じ形状にします。

dout = dout.transpose(0, 2, 3, 1)

次に、ウインドウ内の要素数を求めて、各ウインド内の要素ごとに勾配を格納するためのdmaxを初期化しておきます。

pool_size = self.pool_h * self.pool_w
dmax = np.zeros((dout.size, pool_size))

np.arange(self.arg_max.size)でarg_maxのインデックスを取得するための配列を生成し、self.arg_max.flatten()で順伝播で得られた最大値のインデックスを平準化したら、前層の勾配を1次元配列に変換して順伝播で最大値を持っていた場所にのみ勾配を格納します。

最後にこの勾配を、並び替え済みのdoutにpool_sizeを加えた(データ数、出力高さ、出力幅、チャンネル数、プーリングウインドウサイズ)に再配置します。

dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
dmax = dmax.reshape(dout.shape + (pool_size,)) 

col2im関数を使って元の入力データ形状に戻すために、dmaxを2次元配列のdcolに変換したら、col2im関数で変換し元の入力データの勾配であるdxに格納します。

dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)



■プーリング層クラス

以上をまとめると、プーリング層は下記のように記述できます。

class Pooling:
    def __init__(self, pool_h, pool_w, stride=2, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad
        
        self.x = None
        self.arg_max = None

    def forward(self, x):
        N, C, H, W = x.shape
        out_h = int(1 + (H - self.pool_h) / self.stride)
        out_w = int(1 + (W - self.pool_w) / self.stride)

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h*self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        dout = dout.transpose(0, 2, 3, 1)
        
        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,)) 
        
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
        
        return dx



■おわりに

今回はプーリング層を定義してみました。Max Poolingは挙動が理解しやすく計算コストも低く、特徴量を上手く捉えやすいので多くの場合に選択される手法ですが、画像データなどにノイズが乗っていて外れ値(特に最大側への外れ値)が多い場合などには、そのノイズを重要な特徴量として扱う可能性があるので注意が必要です。

事前に正規化を用いたり、Average Poolingを試すことでより上手くデータに適応できるかもしれません。

■参考文献

  1. Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
  2. 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
  3. ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
  4. API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html
  5. PyTorch documentation. pytorch.org. https://pytorch.org/docs/stable/index.html
  6. Keiron O’Shea, Ryan Nash. An Introduction to Convolutional Neural Networks. https://ar5iv.labs.arxiv.org/html/1511.08458


コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です