6.2 分类模型构建与训练

> 模型构建

> 定义待输入数据的占位符

在训练模型的时候,需要带标签的⼀批样本作为训练数据进⾏训练。在优化过程中,把这些样本逐步代入优化器中。

这里需要定义两个占位符:⼀个是特征数据的占位符,即x,这个特征就是图片,每张图片是28×28大小的灰度图,所以它有784列,相当于把它每⼀行像素点都拉平,形成⼀个⼀维数组。因为需要批量进行训练,每次代入的样本是⼀个批量,所以对于行先设为None就可以了。另外一个是标签数据的占位符,即y,因为这个标签是转换成one hot形式的,是⼀个十分类的标签,所以列数为10列,行跟刚才的特征值⼀样,用None保留,允许后面产生多行。

> 定义模型变量

这⾥定义了两个变量:⼀个是权值W;另⼀个是偏置项b。

W相对来说复杂⼀点,它是一个784×10的二维数组,b只是⼀个10×1的一维数组。对于偏置项b,以常数0进⾏初始化即可。对于W,需要⽤符合正态分布的随机数做初始化,调⽤tf.random_normal,产生一个符合正态分布的、符合784×10这个shape的随机数。

下面具体解释一下tm.random_normal。

以上图的例⼦来说,先调⽤tm.random_normal函数生成100个随机数,生成一个张量赋给norm, 由于norm是⼀个操作,需要在Session里运⾏。运⾏完后,就可以把它的值输出。在上图中输出了前10个数。

因为总共有100个随机数,为了更好的观察,⽤图形化的模式更⽅便。

通过matplotlib中的直⽅图模式,把数据以直⽅图的形式打印出来。可以看到,它们是符合正态分布的。

> 定义前向计算

因为该例⼦⽤的模型很简单,只需⼀个神经元。其操作只是用输⼊特征值跟权值相乘然后相加,这⾥⽤的是矩阵的叉乘。把所有的x乘以相应的w的和累加起来,最后加上b,实现一个前向计算,并将结果赋给y。

在这个输入中有784个像素点,所以特征点的特征数据较多,如果从X1开始算起,n就是从1到784,forward就是图中所表⽰的y值。

但仅有前面这些操作还不够,因为这是一个分类应⽤,所以下⼀步需要对计算出来的y值进⾏结果分类。

> 结果分类

这里可以直接调⽤TensorFlow中提供的softmax函数来实现分类,即把计算出来的y值通过softmax函数转换成为它属于这十个类别中哪⼀类的概率值,转换后的y是一个有⼗个元素的向量,其中的每个分量表示属于对应类别的概率值。具体的算法会在下⼀⼩节中详细讲解。

我们会从简单的单个数值(比如房价)预测的问题,转移到多分类的问题,所⽤⽅法也会从线性回归转到逻辑回归。

> 逻辑回归

许多问题的预测结果是一个在连续空间的数值,比如房价预测问题,可以用线性模型来描述:

但也有很多场景需要输出的是概率估算值,例如:根据邮件内容判断是垃圾邮件的可能性;根据医学影像判断肿瘤是恶性的可能性;手写数字分别是0、1、2、3、4、5、6、7、8、9的可能性。这些都是分类问题。

要预测各类别的概率值,需要将预测输出值控制在[0,1]区间内。对于二元分类问题,它的目标是正确预测两个可能的标签中的一个,就像判断是不是垃圾邮件,答案只有是或否两种。此时可考虑⽤逻辑回归(Logistic Regression)来处理这类问题。

逻辑回归也叫回归,但本质上它能更好的处理分类问题。

> Sigmoid函数

如何保证输出值始终落在0和1之间呢?

这里有⼀个⾮常好的函数叫Sigmoid函数(S型函数),它的输出值正好具有以上特性。其定义如下:

它的定义域z∈(-∞,∞),值域y∈[0,1]。当z=0时,y=0.5。

整个函数呈S型,连续可微,这为后⾯做优化提供了很好的基础,因为后⾯用到的梯度下降优化算法是要对函数求偏导的。

针对Sigmoid函数,特定样本的逻辑回归模型的输出会是什么样呢?

如果把通过线性模型计算出来的z,代到sigmoid函数的公式中,就能形成下图中S型的函数曲线:

这样就能保证,不管z取什么值,y始终落在[0,1]的区间中。

对于⼆元分类,⽐如z=2时,y>0.5,此时就认为正的标签更符合。反之,如果y<0.5,就认为反面标签成⽴。比如垃圾邮件的判断,如果概率已经达到0.7,就认为它是垃圾邮件;当概率低于0.5,就认为它不是垃圾邮件。

通过sigmoid函数,可以把线性模型的输出映射到0到1之间的概率,从而实现⼆元分类的思想。

> 逻辑回归中的损失函数

在模型优化时需要计算损失函数,前面线性回归的损失函数是MSE均方差损失函数,即平方损失,如果把逻辑回归的损失函数也定义为平方损失,会得到下图的函数:

𝑖表示第𝑖个样本点;𝑧𝑖=𝑥𝑖∗𝑤+𝑏;𝜑(𝑧𝑖)表示对𝑖个样本的预测值,即把z的值映射到[0,1]中去;𝑦𝑖表示第𝑖个样本的标签值。

它所做的就是预测值减去标签值的平⽅再求均值,这个就是平⽅损失函数。

如果把sigmoid函数代入上图的公式中,得到完整的损失函数,这个损失函数是一个⾮凸的函数,是有多个极⼩值的。

⽤梯度下降优化算法就很有可能使优化过程陷⼊到局部最优的局部极⼩值中去, 使得函数不能达到全局最优。所以建议大家在逻辑回归中不要采⽤平⽅损失函数。

那这里应该采⽤什么样的损失函数呢?

对于二元逻辑回归的损失函数,一般采用对数损失函数,其定义如下:

(𝑥,𝑦)∈𝐷是有标签样本(𝑥,𝑦)的数据集;𝑦是有标签样本中的标签,因为这是一个二元逻辑回归,所以它的取值必须是0或1;𝑦'是对于特征集𝑥的预测值(介于0和1之间)。

为什么要采用这样的函数呢?

假设样本(x,y)中的标签值y=1,那么理想的预测结果y'=1。y与y'越接近,损失就越⼩。如果y'=1,上图公式的后半部分为0,整个损失函数就变成了-log(y')。根据对数函数的性质,当y'越接近1,-log(y')的值越⼩的。反之,如果标签值y=0,上图公式的前半部分为0,整个损失函数就变成了-log(1-y'),当y'越⼩时,损失值越⼩。

通过这样的对数损失函数,就能较好的刻画预测值和标签值之间的损失关系,⽽且这个损失函数是凸函数:

这样就能通过梯度下降的⽅法找到最优解。

> 多元分类和Softmax

> Softmax思想

之前已经提到逻辑回归可生成介于0和1之间的小数。例如,某电子邮件分类器的逻辑回归输出值为0.8,表明电子邮件是垃圾邮件的概率为80%,不是垃圾邮件的概率为20%。很明显,这封电子邮件是垃圾邮件与不是垃圾邮件的概率之和为1。

在处理多元分类中,Softmax将逻辑回归的思想延伸到多类别领域。

在多类别问题中,Softmax为每个类别分配一个小数形式的概率,介于0到1之间,并且这些概率的和必须是1。

下面看一个手写数字识别的例子:

上图左边部分的这张图像大概能够辨认是3,但并不一定是3,也可能是5或8,这与每个人的手写方式有关。MNIST数据集总共有十个类别,每张手写数字图像对应每个类别有一个概率,相似度越高,概率就越高,反之则越低。比如上图右边的表格里3的概率最大(0.721),说明这个手写体最有可能是的类别是3,而且这十个类别的概率相加一定等于1。这里就是把前面二元分类的sigmoid延伸到了多类别的情况,而不变的是所有类别的概率和还是等于1。

我们来看一下神经网络中Softmax层所处的地位:

前面提到,通过一个神经元实现线性运算得到y值,y值通过Softmax层做多分类的判别,对于每一个类别,我们的目标就是判断是否属于该类别。比如对0回答是或者否,对1回答是或者否,以此类推,不过0-9中只能有一个为是,其他全部为否。

Softmax层实际上是通过Softmax方程来实现,把y的值经过运算,映射到多分类问题中属于每个类别的概率值:

该公式如下图所示:

这里的yk指的是所有的类别,也就是这里y1-y9,十个分类。分母上对于每一个类别的值进行求和,分子取特定的值,得到的pi必然是一个大于0小于1的值。

此公式本质上是将逻辑回归公式延伸到了多类别。

接下来,通过一个例子来看看它如何计算:

有一个向量Y,它有四个元素,如果对它进行Softmax运算,实际上就是分别把值代入Softmax公式,最终得到右边的向量,经过运算后各个数字之间的差距更大了,而且它们都被映射到[0,1]区间,但它们在大小排序上保留了原本的次序。右边这个向量的每个取值就可以认为是该类别的概率。

Softmax的计算在TensorFlow 中已经实现了,所以直接调用即可。

> 交叉熵损失函数

对于多元分类又该采用什么样的损失函数呢?

这里要介绍一下交叉熵的概念:交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的。如果用在概率分布中,比如给定两个概率分布p和q,通过q来表示p的交叉熵如下图所示:

交叉熵刻画的是两个概率分布之间的距离,p代表正确答案,q代表的是预测值,交叉熵越小,两个概率的分布约接近,损失越低。对于机器学习中的多分类问题,通常用交叉熵做为损失函数。

下面来看一个交叉熵计算的例子:

假设有一个3分类问题,某个样例的正确答案是(1,0,0),即它属于第一个类别。甲模型经过softmax回归之后的预测答案是(0.5,0.2,0.3),乙模型经过softmax回归之后的预测答案是(0.7,0.1,0.2)。它们俩哪一个模型预测的更好一些呢(更接近正确答案)?

通过下面交叉熵的计算可以看到,乙模型的预测更好:

有了这个定义,把它用在TensorFlow中,loss值就可以用下图表示:

其中:𝑦𝑖为标签值,𝑦𝑖’为预测值。

在TensorFlow中,可以通过它提供的几个函数来实现:

这里的reduce_sum用来做累加,而reduce_mean是求均值,把均值作为损失函数的定义。

> 模型构建与训练实践

> 载入数据

首先引入TensorFlow包,并读取相应的数据,解压后赋给mnist。

> 构建模型

> 定义占位符X、Y

> 定义变量

> 用单个神经元构建神经网络

> 进行Softmax

> 训练模型

> 设置训练参数

这些训练参数大部分是超参数。目前设定训练轮数为50轮,因为这里打算采用小批量输入样本的方式,就涉及到一个批量要输入多少样本,需要定义batch_size(批次大小),这里设置为100。所有样本训练一轮所需要的批次数是用训练集样本数除以批次大小。在该例子中,训练集中图片数量是55000,那么除以batch_size=100,一个epoch训练需要550次。display_steps用来显示粒度,表示后面显示当前的损失及精确率的控制粒度。还有一个超参数是学习率,初始化为0.01。所谓的训练模型就是根据训练的结果去调整这些超参数。

> 定义损失函数

这里所定义的损失函数是之前所讲的交叉熵损失函数。

根据定义,利用TensorFlow提供的函数可以定义损失函数。y是标签值,现在是一个占位符,把它乘以预测值的对数,然后相加。最后,损失函数是要衡量一批样本预测结果的损失大小,所以还要取均值。损失函数的定义就是由这样一条语句所构成。

> 选择优化器

这里还是采用梯度下降的优化器,代入学习率,目标是最小化loss_function。

> 定义准确率

在分类问题中有一个新的特点——定义准确率。这个准确率是指对测试集中的全体样本,其预测的分类值和标签实际的分类值相等的百分比。

图中的tf.argmax中的argmax函数能返回数组中最大值的下标。tf.argmax(pred,1)是针对预测值的,tf.argmax(y,1)是针对标签的,只要判断这两个值是否相等即可得到分类准确率。如果相等则返回true,否则返回false。

agrmax 带了两个参数,第一个参数是y,第二个参数是一个数字,这个数字指的是针对的维度。因为这里是批量的代入数据,所以y是一个二维的矩阵而不是一维的向量,遇到这种多维的情况,第二个参数指的就是取最大值的下标时所针对的维度。

这里通过一个例子看一下argmax的用法:

首先导入两个必须的第三方的库tensorflow和numpy。然后定义两个数组,第一个arr1是一维数组,也就是一个向量,第二个arr2是一个四行三列的数组。从输出结果可以看出它们的维度。

对于argmax的用法,最主要关注的是后面这个参数。

如果第一个参数代入arr1,因为它是一个向量,这里不需要指定第二个参数就能计算出值来,这个结果通过后面的session运行可以得到。因为arr1中的最大值是7,下标为4,所以第一个输出是4。

如果第一个参数代入arr2,并且指定第二个参数为0,表示按第一维的元素进行取值,也就是行。因为行优先,行列的第一维是行,第二维也是最后一维,就是列。怎么理解按第一维的元素取值呢?其实就是同一列的每一行。运行后可以看到输出结果是3 2 0,只有三个值,针对的是每一列中所有的行。对照arr2的定义可以看到,在第一列中最大值为8,它的行下标为3,第二列最大值为7,行下标为2,第三列最大值为3,行下标为0。当第二个参数为1,表示对第二维(即所有的列)的元素进行取值,就是同一行的每一列,所以出来的结果是2 0 1 0。这跟我们定义准确率时所带的参数一致,也是1。如果把arr2当作预测值所对应数组,相当于有四个样本的三分类问题。那结果也需要有四行(每个样本一个预测结果),只不过是返回每一列的最大值所对应的下标,所以第二个参数设置为1。

当第二个参数为-1,表示最后一维。在python中讲列表的时候有提到过,可以用-1表示最后。对于二维数组来说,最后一维其实还是列,所以这里用1和-1的效果是一样的,其输出结果也是相同的。

明白了argmax之后,再回到定义准确率,correct_prediction得到的实际上是tf.equal比较后的值,tf.equal比较的是它的第一个参数和第二个参数是否相等,相等为真(返回1),不相等为假(返回0)。但在运算的时候,需要把0或1的布尔值转换成TensorFlow中的浮点数。

这里用tf.cast做投射,将其转换成为浮点数,转换之后就可以对批量的样本求准确率了。因为它的值有0有1,比如总共100个样本,其中89个是1,11个是0,就可以知道其中89个样本的预测值和标签值匹配上了,还有11个没有匹配。根据求均值得出预测的准确率为0.89。

然后进行会话声明和初始化变量。

> 模型训练

第一个for循环是训练轮数,总共训练50轮(train_epochs=50在前面已定义)。

第二个for 循环是根据前面定义的batch_size指定要运行多少批次(total_batch)。batch_size在前面设置的是100,所以每一次会读入100个样本的值,并将这100个样本的特征值赋给xs,标签值赋给ys。然后把xs和ys代入优化器进行一次训练,即用读入的这100个样本做优化。当这个for循环执行完后,在训练集的55000条样本上的一轮训练就结束了。注意,这里不是用训练数据,而是用验证集的数据重新填充sess.run()的对应参数,算出当前模型的损失及准确率,并将结果赋给loss和acc。然后把这一轮训练的结果输出。以上是每一轮需要做的事情,当50轮全部训练完,最后输出"Train Finished!"。

55000条数据训练起来还是比较快的。可以看到loss值是呈下降趋势的,而准确率逐步上升,且上升的较快,到第30轮时,准确率已经到84%了。

这里总共是50轮的训练周期,但准确率实际上还在不断上升,最后准确率应该是能达到90%的。大家可以尝试自己修改这个模型,使它的准确率达到90%以上。

> 评估模型

模型训练完后就需要评估模型,训练出来的模型面对新的数据是否也有效。在前面训练过程中用到的数据,只有训练集和验证集,而测试集还没有跟模型见过面,所以测试数据是严格保密的,就像高考试卷一样。

现在模型训练好了,将通过测试集的数据来进行测试。在测试环节中没有分批,所有的测试数据都是一次就读入了,即一万条测试样本直接执行并输出结果,这个效率还是挺高的。

从运行结果可以看到,测试集上的精度是86.46%,前面用验证集的数据测出来的精度是86.06%,两者基本没有太大的差别。

如果大家还感兴趣的话,可以再用验证集的数据重新跑一下这个模型,得到的结果跟刚才训练的结果是一致的,86.06%。

再看看它在训练集上的效果怎么样,得到的结果是85.5364%。

通过对数据集的合理划分,它在训练集和验证集的效果基本上是差不多的,如果对这个准确率还比较满意,就可以进行模型的应用了。

> 模型应用与可视化

> 应用模型

模型应用的这句话非常简单,只需要用到前面已经定义好的pred操作以及把需要进行预测的图像或图像样本集(mnist.test.images)作为输入填充到占位符x中,之后通过argmax函数把pred得到的one hot形式的结果转换成所需要的0-9的数字,就能得到预测结果。这里输入的参数是测试集的样本:

查看预测结果中的前十项,就可以看到对应的结果:

这是一个10个元素的一维数组,比如第一个元素7代表了第一幅图像的预测结果是数字7,2代表了第二幅图像的预测结果是2,以此类推。

> 定义可视化函数

由于这样的输出结果没有跟输入的图像对应起来,不太直观,所以可以通过定义可视化函数,使得输入图像的标签和预测结果同时显现出来,这样效果更好。

调用了plot_images_labels_prediction()函数并代入相关参数后,得到的可视化效果如上图所示。

这里显示的图像是训练得到的图像,每张图像上面是标题栏,标题栏中显示了这张图像的标签值和预测值。每行显示5张图片,共显示了10张图片。

这个函数最多可以显示25张图片:

首先把需要的库引入进来,导入matplotlib和numpy。

然后定义函数名为plot_images_labels_prediction,它共有五个参数:第一个参数是图像的数据列表;第二个参数是标签列表;第三个参数是预测值列表;第四个参数是下标,表示从列表的第几项开始显示;第五个参数是数量,表示一次显示多少幅图片,这里把缺省值设置为10,如果不给num赋值,那么它默认显示的是10幅图像。

再看函数内部的语句,第一句是获取当前图表,通过plt中的gcf ( get current figure )函数得到当前图表,然后把它赋给变量fig。下一步是设置当前图像的大小,因为要显示很多幅图像,所以设得比较大,宽和高分别是10英寸和12英寸。接下来对输入的参数num做了一个判断,如果num>25,则最多赋值为25,即每次显示的图像最多只有25张。

针对每一幅图像的处理用到了一个循环num次的for循环。首先通过plt的sunplot函数获取当前要处理的子图。因为子图最多有25张,所以它的尺寸是5×5,从1开始到num + 1,一个个来处理。用imshow函数把当前样本中第index张图像显示出来。接下来为这幅图像设置title信息,也就是图像上的标题栏,它的显示为“lable= ”,后面跟上当前标签的值。接下来判断预测值列表是否为空,因为有时候图像可能仅用于训练而没有预测值,如果不为空,会在title上显示预测值。然后设置一些显示的参数,比如把title的字体大小设为10。另外,为了美观,这里不显示x和y的坐标轴。最后把index+1,进行下一轮循环。

当num次for循环运行完毕,即所有子图全部构建完成后,用plt的show函数显示结果,就得到了我们开头所看到的结果。

如果我们预测的时候把prediction_result改成“[]”,即空列表,也是可以的,只不过在标题栏中不会显示predict的值,只显示label的值,这时候就是一个训练集的结果。

示例代码

argmax()用法.py
import tensorflow as tf
import numpy as np

arr1 = np.array([1,3,2,5,7,0])
arr2 = np.array([[1.0,2,3],[3,2,1],[4,7,2],[8,3,2]])
print("arr1=",arr1)
print("arr2=\n",arr2)

argmax_1 = tf.argmax(arr1)
argmax_20 = tf.argmax(arr2,0)#指定第二个参数为0,按第一维(行)的元素取值,即同列的每一行
argmax_21 = tf.argmax(arr2,1)#指定第二个参数为1,则第二维(列)的元素取值,即同行的每一列
argmax_22 = tf.argmax(arr2,-1) #指定第二个参数为-1,则第最后维的元素取值

with tf.Session() as sess:
    print(argmax_1.eval())
    print(argmax_20.eval())
    print(argmax_21.eval())
    print(argmax_22.eval())

with tf.Session() as sess:
    print(tf.nn.softmax(arr2).eval())
    print(tf.nn.softmax(arr2,0).eval())
    print(tf.nn.softmax(arr2,1).eval())
了解 tf.random_normal.py
norm = tf.random_normal([100]) #生成100个随机数
with tf.Session() as sess:
    norm_data=norm.eval()
print(norm_data[:10])          #打印前10个随机数

import matplotlib.pyplot as plt
plt.hist(norm_data)
plt.show()

> 基于单个神经元的手写数字识别

载入数据.py
import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
构建模型.py
# 定义x和y的占位符
# mnist 中每张图片共有28*28=784个像素点
x = tf.placeholder(tf.float32, [None, 784], name="X") 

# 0-9 一共10个数字=> 10 个类别
y = tf.placeholder(tf.float32, [None, 10], name="Y")  

# 定义变量
W = tf.Variable(tf.random_normal([784, 10]),name="W")
b = tf.Variable(tf.zeros([10]),name="b") 

forward=tf.matmul(x, W) + b # 前向计算

pred = tf.nn.softmax(forward) # Softmax分类
训练模型.py
# 设置训练参数
train_epochs = 50 # 训练轮数
batch_size = 100  # 单次训练样本数(批次大小)
total_batch= int(mnist.train.num_examples/batch_size)  # 一轮训练有多少批次
display_step = 1  # 显示粒度
learning_rate= 0.01  # 学习率

# 定义交叉熵损失函数
loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), 
                                              reduction_indices=1)) 
                                              
# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

 # 定义准确率
 # 检查预测类别tf.argmax(pred, 1)与实际类别tf.argmax(y, 1)的匹配情况
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# 准确率,将布尔值转化为浮点数,并计算平均值
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

sess = tf.Session() #声明会话
init = tf.global_variables_initializer() # 变量初始化
sess.run(init)

# 开始训练
for epoch in range(train_epochs ):
    for batch in range(total_batch):
        xs, ys = mnist.train.next_batch(batch_size)# 读取批次数据
        sess.run(optimizer,feed_dict={x: xs,y: ys}) # 执行批次训练
    
    #total_batch个批次训练完成后,使用验证数据计算误差与准确率;验证集没有分批   
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, y: mnist.validation.labels})
    # 打印训练过程中的详细信息
    if (epoch+1) % display_step == 0:
        print("Train Epoch:", '%02d' % (epoch+1), "Loss=", "{:.9f}".format(loss),\
              " Accuracy=","{:.4f}".format(acc))

print("Train Finished!")  
评估模型.py
# 测试集
accu_test =  sess.run(accuracy,
                      feed_dict={x: mnist.test.images, y: mnist.test.labels})

print("Test Accuracy:",accu_test)

# 验证集
accu_validation =  sess.run(accuracy,
                      feed_dict={x: mnist.validation.images, y: mnist.validation.labels})

print("Test Accuracy:",accu_validation)

# 训练集
accu_train =  sess.run(accuracy,
                      feed_dict={x: mnist.train.images, y: mnist.train.labels})

print("Test Accuracy:",accu_train)
进行预测.py
# 由于pred预测结果是one-hot编码格式,所以需要转换为0~9数字
prediction_result=sess.run(tf.argmax(pred,1), 
                           feed_dict={x: mnist.test.images })
                           
#查看预测结果中的前10项
prediction_result[0:10] 
定义可视化函数.py
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,      # 图像列表
                                  labels,      # 标签列表
                                  prediction,  # 预测值列表
                                  index,       # 从第index个开始显示
                                  num=10 ):    # 缺省一次显示 10 幅
    fig = plt.gcf() # 获取当前图表,Get Current Figure
    fig.set_size_inches(10, 12)  # 1英寸等于 2.54 cm
    if num > 25: 
        num = 25            # 最多显示25个子图
    for i in range(0, num):
        ax = plt.subplot(5,5, i+1) # 获取当前要处理的子图
        
        ax.imshow(np.reshape(images[index],(28, 28)),  # 显示第index个图像
                  cmap='binary')
            
        title = "label=" + str(np.argmax(labels[index]))  # 构建该图上要显示的title信息
        if len(prediction)>0:
            title += ",predict=" + str(prediction[index]) 
            
        ax.set_title(title,fontsize=10)   # 显示图上的title信息
        ax.set_xticks([]);  # 不显示坐标轴
        ax.set_yticks([])        
        index += 1 
    plt.show()
    
    
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,10,25)

Last updated