TensorFlow读取单张图片并显示

我们知道图片数据是一个三维数组,每个像素点有RGB三个值,图片源文件是将这个三维数组按一定形式编码然后存储的。

所以TensorFlow读取图片的步骤是:先直接按字节将图片源文件读取出,然后decode解码成三维张量(数组),张量再转换成ndarray,然后就可以显示图像了。

其中读取图片的原始信息的函数有两个:
tf.gfile.FastGFile(‘test.jpg’, ‘rb’).read()
tf.read_file(‘test.jpg’)

下面是具体代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import tensorflow as tf
import matplotlib.pyplot as plt

# 1. tf.gfile.FastGFile(filename, mode), mode='r':图片utf-8编码, mode='rb':非utf-8编码
def readImg_1():
# 读入二进制文件
img_raw = tf.gfile.FastGFile('test.jpg', 'rb').read()

# 根据图片格式解码图片
img = tf.image.decode_jpeg(img_raw) # Tensor

with tf.Session() as sess:
img_ = img.eval() # ndarray
print(img_.shape)

plt.imshow(img_)
plt.show()

# 2.通过tf.read_file()读入图片文件
def readImg_2():
img_raw = tf.read_file('test.jpg')
img = tf.image.decode_jpeg(img_raw, channels=3)

with tf.Session() as sess:
img_ = img.eval() # 必须在会话中才能eval取值
print(img_.shape)

plt.imshow(img_)
plt.show()

readImg_1()
readImg_2()

同时 TensorFlow 还可以以队列的形式一下读取很多图片,然后供机器学习来训练。

还可以使用dataset构建数据集,操作和训练都更为方便。