AI测试 TensorFlow 学习记录 (二:MNIST 数据集显示)

膨化先生 · May 08, 2019 · 2419 hits

一、访问MNIST数据集

TensorFlow 学习记录 (一)中我们知道MNIST包括训练集、测试集和验证集,通过MNIST对象可以访问它们。

from TensorFlow import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 访问MNIST训练集
print(mnist.train.images.shape)
print(mnist.train.labels.shape)
# 访问MNIST测试集
print(mnist.test.images.shape)
print(mnist.test.labels.shape)
# 访问MNIST验证集
print(mnist.validation.images.shape)
print(mnist.validation.labels.shape)

结果是:
(55000, 784)
(55000, 10)
(10000, 784)
(10000, 10)
(5000, 784)
(5000, 10)

可以看到与给定相同,训练集有55000张图片,测试集有10000张图片,验证集有5000张图片。image有784节点,也就是28*28像素;label有10个节点,也就是0~9十个数字。

二、MNIST图片数组显示

from TensorFlow import input_data
import numpy as np

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# 设置输出结果保留精度为1位小数
np.set_printoptions(precision=1)

print(mnist.train.images[1, :].reshape(28, 28))

查看数组显示,为方便观看对概率的精度只保留了一位小数:

三、MNIST图片可视化

from TensorFlow import input_data
import matplotlib.pyplot as plt

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# 获取第二张图片
image = mnist.validation.images[1, :]
# 将图像数据还原成28*28的分辨率
image = image.reshape(28, 28)
# 打印对应的标签
print(mnist.validation.labels[1])

plt.figure()
plt.imshow(image)
plt.show()

使用plt库将MNIST数据集图片可视化,可以看到图片中显示数字为3,标签为[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.],也为3

如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!
共收到 0 条回复 时间 点赞
需要 Sign In 后方可回复, 如果你还没有账号请点击这里 Sign Up