TensorFlow中tfrecords的使用

在训练模型的时候会有成千上万图片数据用来训练,如果一张张加载图片有两个缺点:一是速度慢,二是因为训练时还需要对图片进行增强(随即翻转、随机亮度调整等),会使程序代码变得复杂。

TFRecord 便可以解决此问题,它能将数据进行序列化并存储在一个或一组文件中,这样就可以高效的读取数据。

也就是将所有图片转换成二进制信息,全部存放在一个文件中。用的时候直接读取这个文件,并转换成dataset来使用。当然,二进制文件存放的还有图片的结构化信息(label、shape等)。

TERecord的官方介绍链接:https://www.tensorflow.org/tutorials/load_data/tfrecord#write_the_tfrecord_file 。不知道是版本原因还是什么,官方介绍中有的地方运行不出来。

TFRecords使用方法:

1.首先创建tfrecords文件

  • 读取图片的二进制信息(二进制信息和字节信息应该一样)
  • 创建一个存储此图片的字典,内容包括图片二进制信息、label、shape等需要用到的东西
  • 将此字典转化成 tf 的Example 格式,也就是将信息再格式化一下。
  • 利用相关库将生成的 Example 写入文件,文件名可自取。
    (自己理解:这里类似于python的open、write函数,打开一个文件,一次写一张图片信息,一张图片占一行,写完所有图片后关闭此文件。之后读文件的时候直接先读取成dataset,一行是一条数据,然后利用dataset.map()解析每行数据)

2.读取文件并生成dataset

  • 利用生成的tfrecords文件构建成dataset,dataset中每条数据是一张未解析的二进制图片信息。
  • 编写解析的函数,利用dataset.map() 解析dataset中的图片。
  • 解析后就得到可以使用的dataset了。

写入TFRecords文件的代码:

文件夹中有3中图片:1.jpg、2.jpg、3.jpg ,label 分别是1、2、3。

  1. 首先编写类型转换函数,因为写入的信息需要用固定的类型
1
2
3
4
5
6
7
8
9
10
11
12
13
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  1. 编写将所有图片信息写入TFRecords文件的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 写入TFRecords文件的函数,传入图片名列表、标签列表
def image2tfrecord(images_name,labels):
# 打开文件,类似open() 函数
record_file = "img.tfrecods"
writer = tf.io.TFRecordWriter(record_file)

# 遍历图片,并写入文件
for i, image in enumerate(images_name):
# 读取图片的二进制信息
image_string = tf.gfile.GFile(image, 'rb').read()

# 获取图片的形状 (也可以不获取,看自己需要)
image_shape = sess.run(tf.image.decode_jpeg(image_string)).shape
print(image_shape)

# 构建存放图片信息的字典
feature = {
'height': _int64_feature(image_shape[0]), # 需要将数据转化成固定的类型
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(labels[i]),
'image_raw': _bytes_feature(image_string),
}

# 将构建的字典转换成固定的数据类型来写入
img_example = tf.train.Example(features=tf.train.Features(feature=feature))
# 写入图片信息
writer.write(img_example.SerializeToString())

# 全部图片写完后关闭文件
writer.close()
  1. 调用写入函数,生成TFRecords文件
1
2
3
images = ['1.jpg', '2.jpg', '3.jpg']
labels = [1,2,3]
image2tfrecord(images,labels)

读取TFRecords文件的代码:

  1. 利用生成的TFRecords文件生成dataset
1
raw_image_dataset = tf.data.TFRecordDataset('img.tfrecods')
  1. 编写解析图片的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
# 先定义图片解析的格式
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}

# 图片解析函数
def _parse_image_function(example_proto):
# 将单张图片解析成刚才定义的格式
return tf.io.parse_single_example(example_proto, image_feature_description)
  1. 利用dataset.map()解析每张图片
1
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
  1. 然后可以查看一下解析的图片
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 定义一个迭代器,查看一些解析后的图片
iterator = parsed_image_dataset.make_one_shot_iterator()

while 1:
next_element = iterator.get_next()
# dataset中只有三张,读取完会退出循环
try:
image_raw = next_element['image_raw']
# 解码图片,同时将tensor转换成三维数组
image = sess.run(tf.image.decode_image(image_raw))
print(image) # 可以打印出每个像素的值

plt.imshow(image)
plt.show()
except:
break

全部代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import tensorflow as tf
import matplotlib.pyplot as plt


sess = tf.Session()

################ 写入图像##########

# 数据类型转换函数
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# 写入TFRecords文件的函数,传入图片名列表、标签列表
def image2tfrecord(images_name,labels):
record_file = "img.tfrecods"
writer = tf.io.TFRecordWriter(record_file)

# 遍历图片,并写入文件
for i, image in enumerate(images_name):
# 读取图片的二进制信息
image_string = tf.gfile.GFile(image, 'rb').read()

# 获取图片的形状 (也可以不获取,看自己需要)
image_shape = sess.run(tf.image.decode_jpeg(image_string)).shape
print(image_shape)

# 构建存放图片信息的字典
feature = {
'height': _int64_feature(image_shape[0]), # 需要将数据转化成固定的类型
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(labels[i]),
'image_raw': _bytes_feature(image_string),
}

# 将构建的字典转换成固定的数据类型来写入
img_example = tf.train.Example(features=tf.train.Features(feature=feature))
# 写入图片信息
writer.write(img_example.SerializeToString())

# 全部图片写完后关闭文件
writer.close()


images = ['1.jpg', '2.jpg', '3.jpg']
labels = [1,2,3]
image2tfrecord(images,labels)


############ 读取##################
# 利用刚才生成的TFRecords文件生成dataset
raw_image_dataset = tf.data.TFRecordDataset('img.tfrecods')

# 定义图片解析的格式
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}

# 图片解析函数
def _parse_image_function(example_proto):
# 将单张图片解析成刚才定义的格式
return tf.io.parse_single_example(example_proto, image_feature_description)

# 利用map() 来解析dataset中的每张图片
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

# 定义一个迭代器,查看一些解析后的图片
iterator = parsed_image_dataset.make_one_shot_iterator()

while 1:
next_element = iterator.get_next()
try:
image_raw = next_element['image_raw']

image = sess.run(tf.image.decode_image(image_raw))
print(image)

plt.imshow(image)
plt.show()
except:
break



sess.close()