An entry-level CNN visualization walkthrough based on MXNet. Yes, I bash TF on reflex; MXNet & PyTorch for life.
For details, see the IPython notebook here, which can be run directly.

    import numpy as np
    import os
    import mxnet as mx
    from mxnet import gluon
    from mxnet import image
    from mxnet import nd
    from mxnet import init
    from mxnet import autograd
    from mxnet.gluon.data import vision
    from mxnet.gluon import nn
    from mxnet.gluon.model_zoo import vision as models
    from PIL import Image
    from jupyterthemes import jtplot
    jtplot.style(theme='onedork', grid=False)

    %matplotlib inline
    import matplotlib.pyplot as plt
Change the model download source to a mirror that is faster from within mainland China.

    os.environ['MXNET_GLUON_REPO'] = 'https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/'
Load three pretrained models.

    vgg19 = models.vgg19(pretrained=True)
    vgg16 = models.vgg16(pretrained=True)
    resnet152 = models.resnet152_v1(pretrained=True)
Here we read in a photo of a girl.

    data = nd.array(np.asarray(Image.open('000000.jpg')))

    plt.imshow(np.asarray(Image.open('000000.jpg')))
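Note that the pretrained models take \(224 \times 224\) input images, so we rescale the image here. MXNet also expects a specific input layout, so we transpose the array to channel-first order and reshape it to add a batch dimension.

    # Resize to the 224x224 input size the pretrained models expect.
    data = mx.image.imresize(data, 224, 224)
    # (H, W, C) -> (C, H, W), then scale pixel values to [0, 1].
    data = nd.transpose(data, (2, 0, 1))
    data = data.astype(np.float32) / 255
    # Add a leading batch axis: (C, H, W) -> (1, C, H, W).
    data = data.reshape((1,) + data.shape)
    print(data.shape)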
(1, 3, 224, 224)
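The layout conversion above can also be sketched in plain NumPy, without MXNet. This is only an illustration: `to_nchw` is a hypothetical helper name, and the random array stands in for a photo that has already been resized to 224 x 224.

```python
import numpy as np

def to_nchw(img_hwc):
    """Convert an (H, W, C) uint8 image to a (1, C, H, W) float32 batch in [0, 1]."""
    chw = np.transpose(img_hwc, (2, 0, 1))        # channels first
    scaled = chw.astype(np.float32) / 255.0       # scale pixel values to [0, 1]
    return scaled[np.newaxis, ...]                # add the batch axis

# Stand-in for an already resized 224x224 RGB photo (an assumption, not the real image).
fake_img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = to_nchw(fake_img)
print(batch.shape)  # (1, 3, 224, 224)
```

Pixel `(row, col)` of channel `c` in the original image ends up at `batch[0, c, row, col]`.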
Next we draw a saliency map (see this paper for reference). In essence, it shows which pixel positions have the largest gradient magnitude.

    data.attach_grad()
    with autograd.record():
        out = vgg19(data)
    out.backward()

    plt.figure(figsize=(15, 5))
    # Left: the input image, converted back to (H, W, C) uint8.
    plt.subplot(1, 3, 1)
    plt.imshow((data[0].asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8))
    # Middle and right: per-pixel max over channels of |gradient|.
    plt.subplot(1, 3, 2)
    plt.imshow(np.abs(data.grad.asnumpy()[0]).max(axis=0), cmap='gray')
    plt.subplot(1, 3, 3)
    plt.imshow(np.abs(data.grad.asnumpy()[0]).max(axis=0), cmap=plt.cm.jet)
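The reduction used in the plotting cell (absolute value, then maximum over the channel axis) can be isolated as a small function. The sketch below is NumPy-only and uses a toy linear scorer, which is an assumption made so that the gradient is known in closed form; it is not the VGG model. It also contrasts the gradient of a single class score with the gradient of the sum of all scores. As far as I can tell, calling `backward()` on the whole output vector, as the cell above does, corresponds to the second case.

```python
import numpy as np

def saliency_map(grad_chw):
    """Collapse a (C, H, W) input gradient to an (H, W) map: max over channels of |grad|."""
    return np.abs(grad_chw).max(axis=0)

# Toy linear "classifier" (hypothetical): score_k = sum(W[k] * x), so d score_k / d x = W[k].
rng = np.random.default_rng(0)
W = rng.normal(size=(10, 3, 8, 8))   # 10 classes, 3x8x8 input
x = rng.normal(size=(3, 8, 8))

scores = (W * x).sum(axis=(1, 2, 3))
top = int(scores.argmax())

grad_top = W[top]            # gradient of the top class score only
grad_all = W.sum(axis=0)     # gradient of the sum of all class scores

sal_top = saliency_map(grad_top)
sal_all = saliency_map(grad_all)
print(sal_top.shape, sal_all.shape)
```

The two maps generally differ, so it may be worth checking which of the two a given notebook cell is actually plotting.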
    data.attach_grad()
    with autograd.record():
        out = resnet152(data)
    out.backward()

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1)
    plt.imshow((data[0].asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8))
    plt.subplot(1, 3, 2)
    plt.imshow(np.abs(data.grad.asnumpy()[0]).max(axis=0), cmap='gray')
    plt.subplot(1, 3, 3)
    plt.imshow(np.abs(data.grad.asnumpy()[0]).max(axis=0), cmap=plt.cm.jet)
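There is an interesting phenomenon here: VGG-19 tends to focus on the person's head, whereas ResNet picks out the legs. It is worth trying a few more images to see what happens. In my experiments, VGG tends to bring out the contours, while ResNet latches onto all sorts of odd places. Yet ResNet performs very well, and for now I cannot explain why.
Next we plot the filters. First, take a look at the structure of VGG.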
VGG(
(features): HybridSequential(
(0): Conv2D(3 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Activation(relu)
(2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): Activation(relu)
(4): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(5): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): Activation(relu)
(7): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): Activation(relu)
(9): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(10): Conv2D(128 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): Activation(relu)
(12): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): Activation(relu)
(14): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): Activation(relu)
(16): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): Activation(relu)
(18): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(19): Conv2D(256 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): Activation(relu)
(21): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): Activation(relu)
(23): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): Activation(relu)
(25): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): Activation(relu)
(27): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(28): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): Activation(relu)
(30): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): Activation(relu)
(32): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): Activation(relu)
(34): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): Activation(relu)
(36): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
(37): Dense(25088 -> 4096, Activation(relu))
(38): Dropout(p = 0.5)
(39): Dense(4096 -> 4096, Activation(relu))
(40): Dropout(p = 0.5)
)
(output): Dense(4096 -> 1000, linear)
)
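Print the weights of each convolutional layer (only the first filter of each layer is shown).

    for i in vgg19.features:
        if isinstance(i, nn.Conv2D):
            j = i.weight.data()
            print(j[0])

Plot the first filter of the last convolutional layer (its first input-channel slice). From the picture alone, however, it is impossible to tell what effect this filter has.

    plt.imshow(np.abs(j[0][0].asnumpy()))

Take out the first convolutional layer.

    for i in vgg19.features:
        if isinstance(i, nn.Conv2D):
            j = i
            print(i.weight.data()[0])
            break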
[[[-0.05347426 -0.04925704 -0.06794177]
[ 0.01531445 0.04506842 0.0021444 ]
[ 0.03622622 0.01999945 0.01986402]]
[[ 0.01701478 0.05540261 -0.0062293 ]
[ 0.14164735 0.22705214 0.13758276]
[ 0.12000094 0.2002953 0.09211431]]
[[-0.04488515 0.01267995 -0.01449722]
[ 0.05974238 0.13954678 0.05410246]
[-0.00096141 0.058304 -0.02966315]]]
<NDArray 3x3x3 @cpu(0)>
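To make concrete what this layer does with those weights, here is a minimal NumPy sketch that applies the printed 3x3x3 filter to an image with zero padding 1 and stride 1, matching the `Conv2D(3 -> 64, kernel_size=(3, 3), padding=(1, 1))` entry in the printout. The helper name `conv3x3_same` is hypothetical, the layer's bias term is left out, and the random image is only a stand-in.

```python
import numpy as np

# First filter of the first VGG-19 conv layer, copied from the printout above.
# Axes: (input channel, kernel row, kernel column).
kernel = np.array([
    [[-0.05347426, -0.04925704, -0.06794177],
     [ 0.01531445,  0.04506842,  0.0021444 ],
     [ 0.03622622,  0.01999945,  0.01986402]],
    [[ 0.01701478,  0.05540261, -0.0062293 ],
     [ 0.14164735,  0.22705214,  0.13758276],
     [ 0.12000094,  0.2002953 ,  0.09211431]],
    [[-0.04488515,  0.01267995, -0.01449722],
     [ 0.05974238,  0.13954678,  0.05410246],
     [-0.00096141,  0.058304  , -0.02966315]],
])

def conv3x3_same(img_chw, k):
    """Cross-correlate a (C, H, W) image with a (C, 3, 3) kernel; zero padding 1, stride 1, no bias."""
    c, h, w = img_chw.shape
    padded = np.pad(img_chw, ((0, 0), (1, 1), (1, 1)), mode="constant")
    out = np.zeros((h, w))
    for dy in range(3):
        for dx in range(3):
            # Each kernel tap multiplies a shifted view of the padded image.
            tap = k[:, dy, dx][:, None, None]
            out += (padded[:, dy:dy + h, dx:dx + w] * tap).sum(axis=0)
    return out

img = np.random.uniform(0, 1, size=(3, 16, 16))   # stand-in image, channel-first
fmap = conv3x3_same(img, kernel)
print(fmap.shape)  # (16, 16)
```

Because of the padding, the output has the same height and width as the input, which is why each of the 64 channels can be shown as an image of the original size.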
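See what an image looks like after it passes through the first convolutional layer.

    # Run the layer once instead of once per channel.
    fmap = np.abs(i(data)[0].asnumpy())
    for num in range(64):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(fmap[num])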
/home/samael/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py:528: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
max_open_warning, RuntimeWarning)
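And, well, it is rather magical: the outputs of all 64 channels of the first convolution are shown above. Some of them look a bit like an embossing effect, and in one of them the lips stand out very clearly.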
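The RuntimeWarning above comes from opening 64 separate figures. One way around it, sketched here under the assumption that the feature maps are available as a float `(N, H, W)` NumPy array, is to tile them into a single mosaic and show that in one figure. `tile_maps` is a hypothetical helper, and the random array stands in for the real layer output.

```python
import numpy as np

def tile_maps(maps, cols=8):
    """Arrange (N, H, W) maps into one 2-D mosaic with `cols` maps per row."""
    n, h, w = maps.shape
    rows = -(-n // cols)                          # ceiling division
    canvas = np.zeros((rows * h, cols * w), dtype=np.float64)
    for idx in range(n):
        r, c = divmod(idx, cols)
        lo, hi = maps[idx].min(), maps[idx].max()
        # Rescale each map to [0, 1] so that faint channels stay visible.
        scaled = (maps[idx] - lo) / (hi - lo) if hi > lo else np.zeros((h, w))
        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w] = scaled
    return canvas

# Stand-in for the 64 feature maps of the first layer (random values, not real output).
fake_maps = np.abs(np.random.randn(64, 224, 224))
mosaic = tile_maps(fake_maps)
print(mosaic.shape)  # (1792, 1792)
```

A single `plt.imshow(mosaic, cmap='gray')` would then display all 64 channels at once. Note that the per-map rescaling hides differences in magnitude between channels, which may or may not be what you want.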
    sample = np.random.uniform(150, 180, (224, 224, 3))
Here we generate an image filled with noise and look again at what each filter is doing.

    # ImageNet per-channel mean and standard deviation.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    for channel in range(3):
        sample[:, :, channel] /= 255
        sample[:, :, channel] -= mean[channel]
        sample[:, :, channel] /= std[channel]
    # (H, W, C) -> (1, H, W, C) -> (1, C, H, W)
    sample = sample.reshape((1,) + sample.shape)
    sample = sample.transpose(0, 3, 1, 2)
    sample = nd.array(sample)
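The per-channel loop above can also be written with NumPy broadcasting. This is just an equivalent sketch of the same arithmetic using the same mean and std values; `normalize_hwc` is a hypothetical helper name.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize_hwc(img_hwc):
    """Scale an (H, W, 3) image from [0, 255] to [0, 1], then standardize each channel."""
    # mean and std have shape (3,), so they broadcast over the trailing channel axis.
    return (img_hwc / 255.0 - mean) / std

noise = np.random.uniform(150, 180, (224, 224, 3))
# (H, W, C) -> (C, H, W) -> (1, C, H, W)
batch = normalize_hwc(noise).transpose(2, 0, 1)[np.newaxis, ...]
print(batch.shape)  # (1, 3, 224, 224)
```

Unlike the in-place loop, this version leaves the original `noise` array untouched, which makes it easier to display the raw image afterwards.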
    # Run the layer once on the noise image, then plot each channel.
    fmap = np.abs(i(sample)[0].asnumpy())
    for num in range(64):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(fmap[num])
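That is about all there is to entry-level CNN visualization. I could not find an MXNet tutorial on this online, so I had to figure it out myself. Fortunately, the gluon interface is very similar to PyTorch's, so I could read the MXNet source code and borrow from PyTorch tutorials as I went.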