dataset = tf.data.TFRecordDataset([example_store_file])
dataset = dataset.map(self.parse_example)
dataset = dataset.repeat(repeat)
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(flags.batch_size)
当shuffle buffer_size过大时,会报tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: labels. Can't parse serialized Example. 这个问题网上很难找到答案,只有不断调参数才能测试,见鬼,Tensorflow api经常变,都想换Pytorch了。
buffer_size:1.当无batch时,即为dataset item
2.当有batch时,每个item为一个batch

这里有详细参数说明
https://user-gold-cdn.xitu.io/2018/8/28/16580f396628e48b
网友评论