在用tf2.0进行自定义训练的时候,用tf.data模块进行输入数据的创建!结果在训练的时候出现:训练完第一个epoch后,不报任何错误停止训练!
发现问题出于:在对训练数据进行“打散”操作的时候,没有加repeat()进行重复随机打散操作!修改后如下:
数据打散、划分部分:
# 乱序、批划分:
BATCH_SIZE = 16
# 训练数据:
train_count = len( train_image_path ) # 总数据个数
train_dataset = train_dataset.shuffle(train_count).repeat().batch(BATCH_SIZE) # 乱序 + 划分批次
train_dataset = train_dataset.prefetch(AUTOTUNE) # 新操作:预取到缓存:加速处理
# 测试数据:
test_dataset = test_dataset.batch(BATCH_SIZE) # 乱序 + 划分批次
test_dataset = test_dataset.prefetch(AUTOTUNE) # 新操作:预取到缓存:加速处理
训练部分:
history = model.fit(
train_dataset,
epochs = 10,
steps_per_epoch = len(train_image_path) // BATCH_SIZE,
validation_data = test_dataset,
validation_steps = len(test_image_path) // BATCH_SIZE
)
网友评论