AIって結局何なのかよく分からないので、とりあえず100日間勉強してみた Day71
経緯についてはこちらをご参照ください。
■本日の進捗
- 畳み込み層を理解
■はじめに
今回も「ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装(オライリー・ジャパン)」で、深層学習を学んでいきます。
今回は畳み込みニューラルネットワークの処理の中から、その肝となる畳み込み層を構築していきます。
■畳み込み層
これまでに実装してきたim2col関数やcol2im関数を用いて畳み込みニューラルネットワークの根幹を担う畳み込み層を実装していきます。
入力特徴マップに対して、画像データの特徴量を抽出するための重みを持つためのフィルター(カーネル)をストライド量に合わせて行列全体にスライドさせて内積を行います。
これにより特徴マップの空間的情報を保持したままニューラルネットワークで学習を行うことが可能になります。
■畳み込み層の順伝播
順伝播では、入力データxに対してフィルターWで畳み込み演算を実施し、バイアスbを加算して出力yを計算します。
$$ y_{i, j, k} = \sum_{c=1}^C \sum_{h=1}^{FH} \sum_{w=1}^{FW} x_{c, i+h, j+w} \cdot W_{k, c, h, w} + b_k $$
Cはチャンネル数、FH, FWはフィルターの高さと幅、bkはバイアスで、i, jにはそれぞれstrideが乗っています。
まずは下記の式で算出可能な出力サイズを代入しておきます。
$$ H_{out} = \frac{H_{in} + 2P \ – H_{filter}}{S} + 1 $$
$$ W_{out} = \frac{W_{in} + 2P \ – W_{filter}}{S} + 1 $$
FN, C, FH, FW = self.W.shape N, C, H, W = x.shape out_h = 1 + int((H + 2*self.pad - FH) / self.stride) out_w = 1 + int((W + 2*self.pad - FW) / self.stride)
次にim2col関数を用いて入力データとフィルターを2次元のcol配列に変換します。col配列の形状は(データ数×出力高さ×出力幅, C×フィルター高さ×フィルター幅)のように展開されています。フィルターをreshapeすることで(フィルター数、C×フィルター高さ、フィルター幅)に変換して各フィルターごとに一列に並んだ形状になります。
col = im2col(x, FH, FW, self.stride, self.pad) col_W = self.W.reshape(FN, -1).T
これらの配列を用いて畳み込み演算を行います。転置によって出力形状を(データ数、フィルター数、出力高さ、出力幅)に並び替えます。
out = np.dot(col, col_W) + self.b out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2)
最後に結果を格納して畳み込み演算の結果を返します。
self.x = x self.col = col self.col_w = col_W return out
■畳み込み演算の逆伝播
逆伝播の場合も求めて行きます。
損失のバイアスに関する勾配は、連鎖律を用いて、
$$ \frac{\partial L}{\partial b} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial b} $$
$$ \frac{\partial L}{\partial b} = \sum_{i, j} \frac{\partial L}{\partial y} $$
同様に重みに関する勾配も、
$$ \frac{\partial L}{\partial W} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial W} $$
$$ \frac{\partial L}{\partial W} = \sum_{i, j} \frac{\partial L}{\partial y} \cdot x $$
また、入力データxの各位置は、k個のフィルターによって畳み込まれた結果として出力に影響を与えるため、入力に関する勾配は、
$$ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} $$
$$ \frac{\partial L}{\partial x} = \sum_{k} \sum_{i, j} \frac{\partial L}{\partial y} \cdot W $$
これを実装していきます。
まずは順伝播の出力に対する勾配のデータ形状を、(データ数、高さ、幅、出力チャンネル数)に並ぶように入れ替え、2次元配列に変換します。reshapeの引数-1は自動で算出される次元数です。
dout = dout.transpose(0,2,3,1).reshape(-1, FN)
先ほどのバイアスの勾配は、全データに対して合計を取り、出力チャンネルごとにひとつの値を持ちます。
self.db = np.sum(dout, axis=0)
重みの勾配は、im2col関数によって展開されたcol配列に対して出力勾配との内積で求められることを確認していました。
また、重みの勾配を元のフィルター形状である(フィルター数、チャンネル数、高さ、幅)に戻しておきます。
self.dW = np.dot(self.col.T, dout) self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)
入力に関する勾配も、出力勾配を用いた内積で求めます。
dcol = np.dot(dout, self.col_W.T)
最後に、col2im関数でim2col関数の逆変換を行って、元の入力データ形状に戻しておきます。
dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)
■畳み込み層クラス
以上をまとめると、下記のように記述できます。
class Convolution: def __init__(self, W, b, stride, pad): self.W = W self.b = b self.stride = stride self.pad = pad self.x = None self.col = None self.col_W = None self.dW = None self.db = None def forward(self, x): FN, C, FH, FW = self.W.shape N, C, H, W = x.shape out_h = 1 + int((H + 2*self.pad - FH) / self.stride) out_w = 1 + int((W + 2*self.pad - FW) / self.stride) col = im2col(x, FH, FW, self.stride, self.pad) col_W = self.W.reshape(FN, -1).T out = np.dot(col, col_W) + self.b out = out.reshape(N, out_h, out_w, -1).transpose(0, 3, 1, 2) self.x = x self.col = col self.col_w = col_W return out def backward(self, dout): FN, C, FH, FW = self.W.shape dout = dout.transpose(0,2,3,1).reshape(-1, FN) self.db = np.sum(dout, axis=0) self.dW = np.dot(self.col.T, dout) self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW) dcol = np.dot(dout, self.col_W.T) dx = col2im(dcol, self.x.shepe, FH, FW, self.stride, self.pad) return dx
■おわりに
これまでに実装してきた関数を用いて畳み込み層をクラスとして定義することができました。これでこれまでのニューラルネットワークの各層と同様に、好きなタイミングで好きなだけ畳み込み層を呼ぶことができるようになりました。
■参考文献
- Andreas C. Muller, Sarah Guido. Pythonではじめる機械学習. 中田 秀基 訳. オライリー・ジャパン. 2017. 392p.
- 斎藤 康毅. ゼロから作るDeep Learning Pythonで学ぶディープラーニングの理論と実装. オライリー・ジャパン. 2016. 320p.
- ChatGPT. 4o mini. OpenAI. 2024. https://chatgpt.com/
- API Reference. scikit-learn.org. https://scikit-learn.org/stable/api/index.html
- PyTorch documentation. pytorch.org. https://pytorch.org/docs/stable/index.html
- Keiron O’Shea, Ryan Nash. An Introduction to Convolutional Neural Networks. https://ar5iv.labs.arxiv.org/html/1511.08458