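Here is an example with 3 workers. Every worker runs the same script with its own worker_index, and shard(num_workers, worker_index) keeps only every num_workers-th element, starting at position worker_index.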
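import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

# Suppose you have 3 workers in total; each one launches this script
# with --num_workers=3 and its own --worker_index (0, 1 or 2).
flags = tf.app.flags
flags.DEFINE_integer("num_workers", 3, "Total number of workers.")
flags.DEFINE_integer("worker_index", 0, "Index of this worker (0-based).")
FLAGS = flags.FLAGS

dataset = tf.data.Dataset.range(6)
dataset = dataset.shard(FLAGS.num_workers, FLAGS.worker_index)
iterator = dataset.make_one_shot_iterator()
res = iterator.get_next()

with tf.Session() as sess:
    for i in range(2):  # each worker holds 6 / 3 = 2 elements
        print(sess.run(res))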
Each worker prints two values:
0, 3 on worker 0
1, 4 on worker 1
2, 5 on worker 2
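The split follows a simple rule: shard(num_shards, index) keeps the elements whose position i satisfies i % num_shards == index. A minimal pure-Python sketch of that rule reproduces the output above without TensorFlow. The helper name shard_positions is hypothetical, not a TensorFlow API, and it assumes the input is Dataset.range(n), so positions and values coincide.

```python
def shard_positions(num_elements, num_shards, index):
    """Hypothetical helper mimicking Dataset.range(num_elements).shard(num_shards, index).

    Keeps every element whose position i satisfies i % num_shards == index.
    """
    if not 0 <= index < num_shards:
        raise ValueError("index must be in [0, num_shards)")
    return [i for i in range(num_elements) if i % num_shards == index]


# 6 elements split across 3 workers, as in the example.
for worker in range(3):
    print("worker", worker, shard_positions(6, 3, worker))
# worker 0 [0, 3]
# worker 1 [1, 4]
# worker 2 [2, 5]
```

When the dataset size is not a multiple of the worker count, the shards come out uneven: shard_positions(7, 3, 0) returns [0, 3, 6], while shard_positions(7, 3, 1) returns [1, 4].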