如何使用神经网络识别图像?



作者:论智

 

讲一个最简单的例子吧。用神经网络识别MNIST数据集中的手写数字图像。

首先介绍下MNIST数据集的基本情况。MNIST数据集包括60000张手写数字图像,每个图像的大小为 28x28 像素。这些图像都是标注好了的。下面是其中的一些样本。

然后,我们需要搭建一个神经网络来识别MNIST数据集中的手写数字图像。

怎么搭?

大致上,我们可以认为神经网络做的事是拟合函数。对于函数而言,我们常常通过描述函数的输入和输出是什么样的来描述一个函数。那么,我们可以应用同样的思路到神经网络上。

先考虑输出。显然,输出是数字的分类,也就是,神经网络识别这个数字是0-9这10个数字中的一个。那么,神经网络的输出层应该由10个神经元组成,每个神经元代表0-9之间的一个数字分类。

再考虑输入。前面我们已经提到过,图像的大小是 28x28 像素。28x28 = 784。所以,换句话说,每张图像由784个像素组成。那么输入层我们可以简单粗暴地设置为784个神经元组成的神经网络层。

然后,我们再把输入层和输出层直接连起来。

(图片来源:ml4a.github.io,许可:GPL-2.0)

这网络够简单、够粗暴吧?

这个看上去是儿戏吧?输入图像有784个像素,所以用784个神经元;输出结果是10个分类,所以用10个神经元。然后直接把输入层和输出层连接起来。这也太简单粗暴了呀!

然而,神经网络就是这么神奇!通过训练,这样简单粗暴的网络就能学会分类手写数字。

当我们训练好网络之后,可视化一下连接到第一个输出神经元(数字分类0)的网络。

(图片来源:ml4a.github.io,许可:GPL-2.0)

我们可视化上面的权重w_1、w_2、...、w_784,然后给权重分配颜色,最低的权重是白色,最高的权重是黑色,用灰度的深浅表示权重的大小:

(图片来源:ml4a.github.io,许可:GPL-2.0)

右面这图是不是有点像数字0?

同理,我们可视化所有10个数字分类收到的权重。

(图片来源:ml4a.github.io,许可:GPL-2.0)

仔细看看上面的图,像不像0-9这10个数字?

当然,实际使用的神经网络要比这个复杂很多,相应的,识别的精确度也要高很多。但是我希望这个简单粗暴的神经网络能够让你直观地理解神经网络是如何分类识别图像的。


0