来呀,快活呀~

TFRecord 简介

TFRecord是TensorFlow中常用的数据打包格式。通过将训练数据或测试数据打包成TFRecord文件,就可以配合TF中相关的DataLoader / Transformer等API实现数据的加载和处理,便于高效地训练和评估模型。

TF官方tutorial:TFRecord and tf.Example

TFRecord好!

组成TFReocrd的砖石:tf.Example

tf.Example是一个Protobuffer定义的message,表达了一组string到bytes value的映射。TFRecord文件里面其实就是存储的序列化的tf.Example。如果对Protobuffer不熟悉,可以去看下Google的文档教程

Example 是什么

我们可以具体到相关代码去详细地看下tf.Example的构成。作为一个Protobuffer message,它被定义在文件core/example/example.proto中:

1
2
3
message Example {
Features features = 1;
};

好吧,原来只是包了一层Features的message。我们还需要进一步去查找Features的message定义

1
2
3
4
message Features {
// Map from feature name to feature.
map<string, Feature> feature = 1;
};

到这里,我们可以看出,tf.Example确实表达了一组string到Feature的映射。其中,这个string表示feature name,后面的Feature又是一个message。继续寻找:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};

// 这里摘一个 Int64List 的定义如下,float/bytes同理
message Int64List {
// 可以看到,如其名所示,表示的是int64数值的列表
repeated int64 value = 1 [packed = true];
}

看起来,是描述了一组各种数据类型的list,包括二进制字节流,float或者int64的数值列表。

属于自己的Example

有了上面的分解,要想构造自己数据集的tf.Example,就可以一步步组合起来。

首先用下面的几个帮助函数,将给定的Python类型数据转换为对应的Feature。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# The following functions can be used to convert a value to a type compatible
# with tf.Example.

def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 这里我们直接认为value是个标量,如果是tf.Tensor,可以使用
# `tf.io.serialize_tensor`将其序列化为bytes
# `tf.io.parse_tensor`可以反序列化为tf.Tensor

def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

有了Feature,就可以组成Features,只要把对应的名字作为string传进去就行了。

1
2
features_dict = {'image': _bytes_feature(image_data), 'label': _int64_feature(label)}
features = tf.train.Features(feature=features_dict)

Example自然也就有了:

1
example = tf.train.Example(features=features)

TFRecord

TFRecord是一个二进制文件,只能顺序读取。它的数据打包格式如下:

1
2
3
4
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data

其中,data[length]通常是一个Example序列化之后的数据。

Example写入TFRecord

可以使用python API,将Exampleproto写入TFRecord文件。

1
2
3
4
5
6
7
8
with tf.io.TFRecordWriter(filename) as writer:
for image_file in image_files:
image_data = open(image_file, 'rb').read()
features = tf.train.Features(feature={'image': _bytes_feature(image_Data)})
# 得到 example
example = tf.train.Example(features=features)
# 通过调用message.SerializeToString() 将其序列化
writer.write(example.SerializeToString())

读取TFRecord中的Example

通过tf.data.TFRecordDataset得到Dataset,然后遍历它,并反序列化,就可以得到原始数据。下面的代码段从TFRecord文件中读取刚刚写入的image:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def parse_from_single_example(example_proto):
""" 从example message反序列化得到当初写入的内容 """
# 描述features
desc = {'image': tf.io.FixedLenFeature([], dtype=tf.string)}
# 使用tf.io.parse_single_example反序列化
return tf.io.parse_single_example(example_proto, desc)


def decode_image_from_bytes(image_data):
""" use cv2.imdecode decode image from raw binary data """
bytes_array = np.array(bytearray(image_data))
return cv2.imdecode(bytes_array, cv2.IMREAD_COLOR)


def get_image_from_single_example(example_proto):
""" get image fom example serialized data """
data = parse_from_single_example(example_proto)
image_data = data['image'].numpy()
# the image_data is str
# decode the binary bytes to get the image
return decode_image_from_bytes(image_data)


dataset = tf.data.TFRecordDataset(tfrecord_file)
data_iter = iter(dataset)
first_example = next(data_iter)

first_image = get_image_from_single_example(first_example)

或者可以用map来将parser的pipeline应用于原dataset:

1
2
3
4
5
6
# 注意这里不能用get_image_from_single_example
# 因为 `.numpy()` 不能用于静态 Map
image_data = dataset.map(parse_from_single_example)

first_image_data = next(iter(image_data))
image = decode_image_from_bytes(first_image_data['image'].numpy())