写入和读取TFRecord文件

Note

上节中我们了解到TFRecord文件通常包含的是序列化的Example数据。
本节来讲如何写入和读取这类文件。

写入TFRecord文件

import tensorflow as tf

# for simplicity
BytesList = tf.train.BytesList
FloatList = tf.train.FloatList
Int64List = tf.train.Int64List
Feature = tf.train.Feature
Features = tf.train.Features
Example = tf.train.Example
# 姓名,id,邮箱
data = [[[b"Alice"], [123], [b"alice@a.com", b"alice@b.com"]],
        [[b"Bob"], [22], [b"Bob@c.com"]]]
with tf.io.TFRecordWriter("my_contacts.tfrecord") as f:
    for lst in data:
        # 创建此数据的Example
        person_example = Example(
            features=Features(
                # value都是列表
                feature={
                    "name": Feature(bytes_list=BytesList(value=lst[0])),
                    "id": Feature(int64_list=Int64List(value=lst[1])),
                    "emails": Feature(bytes_list=BytesList(value=lst[2]))
                }))
        # 写入序列化的Example
        f.write(person_example.SerializeToString())

读取TFRecord文件

读取数据时,首先要给出Example的description。

feature_description = {
    "name": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "id": tf.io.FixedLenFeature([], tf.int64, default_value=0),
    "emails": tf.io.VarLenFeature(tf.string),
}

单个解析

for serialized_example in tf.data.TFRecordDataset(["my_contacts.tfrecord"]):
    # 解析Example
    parsed_example = tf.io.parse_single_example(serialized_example,
                                                feature_description)
    # 然后就可以获得数据啦
    print(parsed_example)
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f874a14a6d0>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f874a14ac10>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=22>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Bob'>}

批量解析

批量解析的话要在TFRecordDataset中指定batch,并用 tf.io.parse_example 解析函数替代 tf.io.parse_single_example。

for serialized_example in tf.data.TFRecordDataset(["my_contacts.tfrecord"]).batch(10):
    # 解析Example
    parsed_examples = tf.io.parse_example(serialized_example,
                                         feature_description)
    print(parsed_examples)
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8751435970>, 'id': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([123,  22])>, 'name': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Alice', b'Bob'], dtype=object)>}

预处理数据

有时候光解析数据是不够的,数据还需要经过预处理才能被模型接受,这时需要自定义所需的预处理函数然后map。

def _parse_person(example_proto):
    # 将id乘2并只使用这个特征
    parsed_example = tf.io.parse_single_example(example_proto, 
                                                feature_description)
    person_id = parsed_example["id"]
    return person_id * 2


# 原始的dataset
raw_dataset = tf.data.TFRecordDataset(["my_contacts.tfrecord"])
# 预处理后的dataset
dataset = raw_dataset.map(_parse_person, num_parallel_calls=4)
# as expected
for X in dataset:
    print(X)
tf.Tensor(246, shape=(), dtype=int64)
tf.Tensor(44, shape=(), dtype=int64)