在训练模型的时候会有成千上万图片数据用来训练,如果一张张加载图片有两个缺点:一是速度慢,二是因为训练时还需要对图片进行增强(随即翻转、随机亮度调整等),会使程序代码变得复杂。
TFRecord 便可以解决此问题,它能将数据进行序列化并存储在一个或一组文件中,这样就可以高效的读取数据。
也就是将所有图片转换成二进制信息,全部存放在一个文件中。用的时候直接读取这个文件,并转换成dataset来使用。当然,二进制文件存放的还有图片的结构化信息(label、shape等)。
TERecord的官方介绍链接:https://www.tensorflow.org/tutorials/load_data/tfrecord#write_the_tfrecord_file 。不知道是版本原因还是什么,官方介绍中有的地方运行不出来。
TFRecords使用方法:
1.首先创建tfrecords文件
- 读取图片的二进制信息(二进制信息和字节信息应该一样)
- 创建一个存储此图片的字典,内容包括图片二进制信息、label、shape等需要用到的东西
- 将此字典转化成 tf 的Example 格式,也就是将信息再格式化一下。
- 利用相关库将生成的 Example 写入文件,文件名可自取。
(自己理解:这里类似于python的open、write函数,打开一个文件,一次写一张图片信息,一张图片占一行,写完所有图片后关闭此文件。之后读文件的时候直接先读取成dataset,一行是一条数据,然后利用dataset.map()解析每行数据)
2.读取文件并生成dataset
- 利用生成的tfrecords文件构建成dataset,dataset中每条数据是一张未解析的二进制图片信息。
- 编写解析的函数,利用dataset.map() 解析dataset中的图片。
- 解析后就得到可以使用的dataset了。
写入TFRecords文件的代码:
文件夹中有3中图片:1.jpg、2.jpg、3.jpg ,label 分别是1、2、3。
- 首先编写类型转换函数,因为写入的信息需要用固定的类型
1 | def _bytes_feature(value): |
- 编写将所有图片信息写入TFRecords文件的函数
1 | # 写入TFRecords文件的函数,传入图片名列表、标签列表 |
- 调用写入函数,生成TFRecords文件
1 | images = ['1.jpg', '2.jpg', '3.jpg'] |
读取TFRecords文件的代码:
- 利用生成的TFRecords文件生成dataset
1 | raw_image_dataset = tf.data.TFRecordDataset('img.tfrecods') |
- 编写解析图片的函数
1 | # 先定义图片解析的格式 |
- 利用dataset.map()解析每张图片
1 | parsed_image_dataset = raw_image_dataset.map(_parse_image_function) |
- 然后可以查看一下解析的图片
1 | # 定义一个迭代器,查看一些解析后的图片 |
全部代码:
1 | import tensorflow as tf |