8.4 CIFAR-10图像识别案例的TensorFlow实现

在这一讲中,我们将一起来完成一个案例,这个案例是使用卷积神经网络对CIFAR-10数据集进行图像分类。

> CIFAR-10数据集

CIFAR-10是由Alex Krizhevsky, Vinod Nair和Geoffrey Hinton收集而来的用于图像识别的数据集,它一共有十个分类,每个分类有6000张32×32大小的彩色图像,所以一共有60000张图像,其中50000张用于训练,10000张用于测试。

图8-43

我们可以在下面这个网址查看有关CIFAR-10数据集的介绍:

https://www.cs.toronoto.edu/~kriz/cifar.htmlarrow-up-right

那么这个CIFAR-10数据集跟MNIST数据集相比有什么不同呢?

CIFAR-10是RGB三通道图像,MNIST数据集是灰度图;CIFAR-10的图像尺寸为32×32,MNIST数据集为28×28;CIFAR-10数据集它的色彩和噪点比较多,同一个分类,比如说,在卡车这个分类里,卡车的角度、颜色和大小都不一样,所以CIFAR-10图像识别难度比MNIST数据集要高得多。

> 下载CIFAR-10数据集

图8-44

第一次执行上图程序的时候,程序会检查你是否已经下载过CIFAR-10数据集,如果它没有找到这个文件的话,就会自动下载文件并进行解压;如果已经下载,它会显示提示信息“Data file already exists.”、“Directory already exists.”。

CIFAR 10数据集文件的大小为163M。

如果你的网络网速不是很快,在后面通过代码下载可能会超时。如果遇到这种情况,可以先直接通过以下网址下载:

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gzarrow-up-right

下载的文件名是cifar-10-python.tar.gz,该先在Jupty代码目录下建立data子目录,然后把下载的数据文件放到data目录。

> 导入CIFAR-10数据集

在下载和解压之后,我们需要读取训练数据和测试数据。

在这里,我们定义了两个函数:load_CIFAR_batch( )和load_CIFAR_data( )。

load_CIFAR_batch( )读取一个批次的样本,load_CIFAR_data( )实现数据集的完整读取。

图8-45
图8-46

我们每次读取10000条数据,所以load_CIFAR_data( )通过一个for循环语句来循环地调用load_CIFAR_batch( ),从而获取完整的数据集,最后返回的是训练集的图像和标签以及测试集的图像和标签。

运行之后,我们可以看到data已经loading完毕啦。

图8-47

> 显示数据集信息

我们运用了“.shape”方法来查看训练集和测试集的信息。

图8-48

训练集有50000条数据,测试集有10000条数据,图像的尺寸为32×32,通道为RGB三通道。

> 查看单项image和label

首先,我们导入matplotlib,查看一下里面的第七个样本(因为下标从0开始,所以6代表第七个样本)。

图8-49

我们可以看到,上图中显示了CIFAR-10数据集中的第七个样本(一个模糊的孔雀)。我们还可以查看它的标签,标签值为2,这个2对应的就是第三类,也就是bird类。

具体信息可以在网站http://www.cs.toronto.edu/~kriz/cifar.htmlarrow-up-right上查看。

> 查看多项images和label

当然,我们也可以一次查看多张图像及其标签:

图8-50

首先,我们需要定义一个标签字典,每一个数字代表了对应图像类别的名称。然后,我们定义了一个显示图像数据及其对应标签的函数。这里的输入参数有图像(我们需要分类的)、标签(实际上的类别)、预测的类别,num是我们一次想要展示的图片数量,在这里是10张。

运行之后就会得到下图的结果:

图8-51

这里我们指定的是看前十张图片,图片上方的0是第一张图片,1是第二张图片,数字右边的单词是它们所对应的真正类别,由于我们的输入需要labels和prediction,但是我们现在还没有真正的去做预测,因此prediction是个空列表。

> 数据预处理

接下来,我们要对数据进行必要的预处理,这个数据包括图像数据和标签数据。

> 图像数据预处理

我们怎么来查看图像数据信息呢?

我们可以查看第一个训练样本的第一个像素点:

图8-52

由于图像是三通道的,所以59、62、63这三个数分别代表了图像的第一个像素点在RGB三个通道上的像素值。

然后我们对图像进行数字标准化:

图8-53

因为图像的数字标准化可以提高模型的准确率。在没有进行标准化或归一化之前,图像的像素值是0-255,如果我们想对它进行归一的话,最简单的做法就是除以255。下图就反映了经过处理后的结果:

图8-54

我们可以看到,经过处理后的数值全都在0-1之间,说明我们的数据已经预处理完毕了。

> 标签数据预处理

图8-55

对于CIFAR-10数据集,它的label是0-9

比如船这个分类,它对应的label是8,我们希望通过独热编码来表示它的分类以下是将类别进行独热编码的程序:

图8-56

经过独热编码之后的shape是50000和10:

图8-57

在没有进行独热编码之前,对于训练数据集,前五个样本所对应的标签是6、9、9、4、1:

图8-58

在进行独热编码之后,它们所对应的标签变成了下图的样子:

图8-59

第一个数据原来的真实值为6,经过独热编码转换之后变成了0或1的组合,只有下标为6的位置所对应的值为1,其他都为0。

> 建立CIFAR-10图像分类模型

我们将要建立的卷积神经网络的结构呢,如下图所示:

图8-60

在这个网络结构里,图像的特征提取是通过卷积层1、降采样层1、卷积层2以及降采样层2处理之后得到的。

全连接神经网络是由全连接层、输出层所组成的网络结构。

首先,我们导入所需要的库:

图8-61

> 定义共享函数

图8-62

这个共享函数包括权值、偏置、二维卷积和池化函数。

这里的卷积函数和池化函数,我们就是调用了TensorFlow自带的二维卷积和池化,只不过我们对它的参数进行了指定。比如padding为“SAME”,因为我们希望卷积之后,图像的大小不变。而ksize是2,因为我们希望池化之后,图像的大小变为原来的四分之一,也就是它的宽和高分别变为原来的二分之一。

> 定义网络结构

图8-63
图8-64

这里的网络结构跟图8-60是一一对应的。

首先是输入层,输入层是32×32的图像,RGB三通道,所以这里的shape是None、32、32、3。None是指我们不限定一个批次里样本的数量。

第一个卷积层的输入通道为3、输出通道为32。这里的weight([3,3,3,32]),第一个3是卷积核的宽,第二个3是卷积核的高,第三个3是输入通道数量,32是输出通道数量。然后我们通过卷积核,对图像进行卷积之后,加上一个偏置,得到第一个卷积层的输出conv_1。我们对它进行一个非线性激活,这里采用的非线性激活函数是tf.nn.relu()。

第一个池化层我们采用了最大池化,大家也可以尝试把它换成均值池化,看看结果有什么不一样。

第二个卷积层的操作和第一个卷积层的操作是类似的。

第二个池化层的操作跟第一个池化层的操作也是类似的。

在两个卷积层和池化层之后,就是全连接层。我们先把64个8×8的图像转换为一维向量,这64个8×8的图像就是第二个池化层的输出,转换后一维向量的长度是64×8×8=4096。128指的是这个全连接层神经元的个数,我们也可以调整这个数字。“h = tf.nn.relu(tf.matmul(flat, W3) + b3)”这条语句跟我们之前学习的全连接神经网络是一样的,每一个神经元都和前面的4096个像素点进行全连接。然后我们加入了h_dropout来防止过拟合。

最后是输出层,输出层共有10个神经元,对应到0-9这10个类别。

> 构建模型

在之前我们有提到过:全连接神经网络构建模型、训练模型的方式同样适用于卷积神经网络。所以在这里定义占位符、定义损失函数以及选择优化器这些操作都跟全连接神经网络是一致的,包括准确率的定义。

图8-65

> 定义准确率

图8-66

> 训练模型

接下来,我们对模型进行训练。

> 启动会话

图8-67

在启动会话之前,我们需要指定它迭代的次数、批量的样本大小等等。

> 断点续训

我们知道,程序的训练,尤其对于大规模数据集或者复杂的网络,它的训练时间非常长,往往需要数个小时甚至数天,有时还可能会因为某些原因导致了计算机宕机。这样的话,前面的训练就会前功尽弃。解决的方案呢,就是增加一个断点续训的机制,每次程序执行完训练之后,将模型的权重保存一下,下次程序在执行训练之前,我们先加载这个模型的权重,再继续训练就可以了。

介绍到这里,大家可能回想起这个断点续训跟我们在MNIST案例中介绍的模型的存储和加载很类似。

首先我们定义一个存储路径,这里就用当前目录下的"CIFAR10_log/"目录。当这个目录不存在的时候,我们就会创建一个。

由于我们已经定义完所有的变量了,所以我们可以调用tf.train.Saver()来保存和提取变量。这个变量包含了权重以及其他在程序中定义的变量。

再接下来就是加载模型。如果存储路径中已经有训练好的模型文件,那我们可以用saver.restore()来加载所有的参数,然后就可以直接使用模型进行预测,或者接着继续训练了。

在这里,我们取了“断点续训”这个名字,是因为我们除了希望保存和加载模型之外,还希望知道断点在哪里、我们是从哪里开始继续训练的。我们在启动会话中定义了一个不可训练的变量epoch,然后在断点续训的时候,通过sess.run(epoch)获得它的值,从而我们就可以知道,我们是从第几轮开始继续迭代训练的。

图8-68

上图说明,在我们设置的检查点目录中,目前还没有已经训练好的模型,没有checkpoint文件,所以我们现在的模型是从头开始训练的。

> 迭代训练

图8-69

首先,我们定义一个get_train_batch()函数来批量获取训练数据,它返回的是经过归一化的图像数据以及标签数据。

接下来是for循环。然后我们通过sess.run()来获取模型的损失值和准确率。

25次迭代训练完成后,损失函数值约等于1.84,准确率是62%。

图8-70

如果你觉得62%的准确率还不够满意的话,可以通过以下方式去提升准确率:

图8-71

> 可视化损失值

图8-72
图8-73

在训练的过程中,损失值越来越低,也就是说,我们的训练是有效的。

> 可视化准确率

图8-74
图8-75

准确率的变化趋势是越来越高的。从上面损失值的图片可以看到,它并没有处于一个收敛的状态,因此准确率也还有上升的空间。

> 评估模型及预测

现在,我们已经建立好了模型并且完成了模型的训练,当你觉得训练的准确率已经能够达到你所期待的准确率的时候,你就可以用这个模型来进行预测了。

在CIFAR数据集上,对卷积神经网络进行模型评估及预测跟我们在MNIST数据集上进行的模型评估是一样的。

> 计算测试集上的准确率

图8-76

测试的准确率跟我们刚才训练的准确率是差不多的。

> 利用模型进行预测

图8-77

> 可视化预测结果

图8-78

除了第二个图片模型把船当成了汽车,其他图片模型都分类正确。

> 示例代码

Last updated

Was this helpful?