写入和读取TFRecord文件¶
Note
上节中我们了解到TFRecord文件通常包含的是序列化的Example数据。
本节来讲如何写入和读取这类文件。
写入TFRecord文件¶
import tensorflow as tf
# for simplicity
BytesList = tf.train.BytesList
FloatList = tf.train.FloatList
Int64List = tf.train.Int64List
Feature = tf.train.Feature
Features = tf.train.Features
Example = tf.train.Example
# 姓名,id,邮箱
data = [[[b"Alice"], [123], [b"alice@a.com", b"alice@b.com"]],
[[b"Bob"], [22], [b"Bob@c.com"]]]
with tf.io.TFRecordWriter("my_contacts.tfrecord") as f:
for lst in data:
# 创建此数据的Example
person_example = Example(
features=Features(
# value都是列表
feature={
"name": Feature(bytes_list=BytesList(value=lst[0])),
"id": Feature(int64_list=Int64List(value=lst[1])),
"emails": Feature(bytes_list=BytesList(value=lst[2]))
}))
# 写入序列化的Example
f.write(person_example.SerializeToString())
读取TFRecord文件¶
读取数据时,首先要给出Example的description。
feature_description = {
"name": tf.io.FixedLenFeature([], tf.string, default_value=""),
"id": tf.io.FixedLenFeature([], tf.int64, default_value=0),
"emails": tf.io.VarLenFeature(tf.string),
}
单个解析¶
for serialized_example in tf.data.TFRecordDataset(["my_contacts.tfrecord"]):
# 解析Example
parsed_example = tf.io.parse_single_example(serialized_example,
feature_description)
# 然后就可以获得数据啦
print(parsed_example)
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f874a14a6d0>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f874a14ac10>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=22>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Bob'>}
批量解析¶
批量解析的话要在TFRecordDataset中指定batch,并用 tf.io.parse_example 解析函数替代 tf.io.parse_single_example。
for serialized_example in tf.data.TFRecordDataset(["my_contacts.tfrecord"]).batch(10):
# 解析Example
parsed_examples = tf.io.parse_example(serialized_example,
feature_description)
print(parsed_examples)
{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8751435970>, 'id': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([123, 22])>, 'name': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Alice', b'Bob'], dtype=object)>}
预处理数据¶
有时候光解析数据是不够的,数据还需要经过预处理才能被模型接受,这时需要自定义所需的预处理函数然后map。
def _parse_person(example_proto):
# 将id乘2并只使用这个特征
parsed_example = tf.io.parse_single_example(example_proto,
feature_description)
person_id = parsed_example["id"]
return person_id * 2
# 原始的dataset
raw_dataset = tf.data.TFRecordDataset(["my_contacts.tfrecord"])
# 预处理后的dataset
dataset = raw_dataset.map(_parse_person, num_parallel_calls=4)
# as expected
for X in dataset:
print(X)
tf.Tensor(246, shape=(), dtype=int64)
tf.Tensor(44, shape=(), dtype=int64)