Slicing in TensorFlow works very much like slicing in NumPy. The two functions covered here are documented in the TensorFlow API:
tf.slice
tf.slice(
    input_,     # input tensor
    begin,      # starting index along each dimension
    size,       # number of elements to take along each dimension (-1 = all remaining)
    name=None   # name of the operation (optional)
)
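To make the begin/size rule concrete, here is a minimal pure-Python sketch that mimics tf.slice on nested lists. The helper name `slice_nested` is hypothetical (it is not part of TensorFlow). It assumes in-range indices and follows the documented rule that size[i] == -1 takes all remaining elements along axis i.

```python
def slice_nested(data, begin, size):
    """Hypothetical pure-Python model of tf.slice on nested lists.

    Along axis i, take size[i] elements starting at begin[i];
    size[i] == -1 means "everything from begin[i] to the end".
    """
    if not begin:  # every axis has been sliced: return the remaining value
        return data
    start, count = begin[0], size[0]
    stop = len(data) if count == -1 else start + count
    # slice this axis, then recurse into the remaining axes
    return [slice_nested(item, begin[1:], size[1:]) for item in data[start:stop]]


input_3d = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
            [[4, 4, 4], [5, 5, 5], [6, 6, 6]],
            [[7, 7, 7], [8, 8, 8], [9, 9, 9]]]

print(slice_nested(input_3d, [0, 1, 0], [2, 2, 2]))    # [[[2, 2], [3, 3]], [[5, 5], [6, 6]]]
print(slice_nested(input_3d, [1, 0, 2], [-1, 1, -1]))  # [[[4]], [[7]]]
```

The recursion slices one axis at a time, which is why each axis gets its own begin and size entry.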
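In the example below, a equals b. tf.slice has no stride argument and always steps by 1, so it selects the same elements as tf.strided_slice with strides [1,1,1] and end = begin + size = [0,1,0] + [2,2,2] = [2,3,2].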
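import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

input = [[[1,1,1],[2,2,2],[3,3,3]],
         [[4,4,4],[5,5,5],[6,6,6]],
         [[7,7,7],[8,8,8],[9,9,9]]]
sess = tf.Session()
# start at [0,1,0] and take 2 elements along each axis
a = tf.slice(input, [0,1,0], [2,2,2])
print(sess.run(a))
print("\n")
# start at [0,1,0], stop before [2,3,2], step 1 along each axis
b = tf.strided_slice(input, [0,1,0], [2,3,2], [1,1,1])
print(sess.run(b))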
>>>
[[[2 2]
  [3 3]]

 [[5 5]
  [6 6]]]


[[[2 2]
  [3 3]]

 [[5 5]
  [6 6]]]
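Because tf.slice always steps by 1, any tf.slice call can be rewritten as tf.strided_slice with end = begin + size and all strides equal to 1. The sketch below does this conversion in plain Python. `slice_args_to_strided` is a hypothetical helper, and it assumes the input shape is known so that a size of -1 can be resolved to the end of the axis.

```python
def slice_args_to_strided(begin, size, shape):
    """Hypothetical helper: turn tf.slice(begin, size) arguments into
    equivalent tf.strided_slice(begin, end, strides) arguments."""
    end = []
    for b, s, dim in zip(begin, size, shape):
        # size -1 means "to the end of this axis"; otherwise stop after s elements
        end.append(dim if s == -1 else b + s)
    strides = [1] * len(begin)  # tf.slice never skips elements
    return list(begin), end, strides


print(slice_args_to_strided([0, 1, 0], [2, 2, 2], [3, 3, 3]))
# ([0, 1, 0], [2, 3, 2], [1, 1, 1])
```

The printed triple is exactly the begin, end and strides passed to tf.strided_slice for b.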
tf.strided_slice
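tf.strided_slice(
    input_,              # input tensor
    begin,               # starting index along each dimension
    end,                 # end index along each dimension; ATTENTION: not included!
    strides=None,        # step along each dimension (defaults to 1)
    begin_mask=0,        # bit i set: ignore begin[i] and use the fullest possible range
    end_mask=0,          # bit i set: ignore end[i] and use the fullest possible range
    ellipsis_mask=0,     # bit i set: position i is an ellipsis (...)
    new_axis_mask=0,     # bit i set: insert a new dimension of size 1
    shrink_axis_mask=0,  # bit i set: treat begin[i] as a single index and drop that dimension
    var=None,            # the variable corresponding to input_, or None
    name=None            # name of the operation (optional)
)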
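The begin_mask and end_mask arguments are bit fields. If bit i is set, begin[i] (or end[i]) is ignored and the widest possible range is used along axis i, much like leaving a bound empty in Python's `start:stop:step`. The plain-Python sketch below models only these two masks for a 2-D nested list. `axis_slice` and `strided_slice_2d` are hypothetical names, and the ellipsis, new-axis and shrink-axis masks are not modelled.

```python
def axis_slice(i, begin, end, strides, begin_mask=0, end_mask=0):
    """Hypothetical helper: the Python slice used along axis i.

    A set bit i in begin_mask / end_mask replaces begin[i] / end[i] with None,
    i.e. "start from the beginning" / "run to the end" (reversed for negative strides).
    """
    start = None if begin_mask & (1 << i) else begin[i]
    stop = None if end_mask & (1 << i) else end[i]
    return slice(start, stop, strides[i])


def strided_slice_2d(data, begin, end, strides, begin_mask=0, end_mask=0):
    """Hypothetical model of tf.strided_slice on a 2-D nested list (begin/end masks only)."""
    rows = data[axis_slice(0, begin, end, strides, begin_mask, end_mask)]
    col = axis_slice(1, begin, end, strides, begin_mask, end_mask)
    return [row[col] for row in rows]


grid = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]

print(strided_slice_2d(grid, [1, 1], [2, 2], [1, 1]))                # [[5]]
print(strided_slice_2d(grid, [1, 1], [2, 2], [1, 1], begin_mask=1))  # bit 0: rows start at 0 -> [[2], [5]]
print(strided_slice_2d(grid, [1, 1], [2, 2], [1, 1], end_mask=2))    # bit 1: columns run to the end -> [[5, 6]]
```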
tf.strided_slice on a 1-D tensor: tf.strided_slice(input_data, [begin], [end], [stride]), where the element at end is not included.
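import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

input_data = [1,2,3,4,5,6,7,8,9]
a = tf.strided_slice(input_data, [0], [4], [1])     # indices 0..3
b = tf.strided_slice(input_data, [0], [-1], [1])    # -1 = last element, which is excluded
c = tf.strided_slice(input_data, [0], [-1], [2])    # every second element
d = tf.strided_slice(input_data, [0], [-1])         # strides omitted -> 1
e = tf.strided_slice(input_data, [-1], [-2], [-1])  # walk backwards from the last element
with tf.Session() as sess:
    for t in (a, b, c, d, e):
        print(sess.run(t))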
a: begins at input_data[0] and stops before input_data[4], with stride 1:
[1 2 3 4]
b: begins at input_data[0] and stops before the last element (index -1), with stride 1:
[1 2 3 4 5 6 7 8]
c: begins at input_data[0] and stops before the last element, with stride 2:
[1 3 5 7]
d: same as b, because the stride defaults to 1 when strides is omitted:
[1 2 3 4 5 6 7 8]
e: begins at the last element (index -1) and walks backwards with stride -1, stopping before the second-to-last element (index -2), so only one element is taken:
[9]
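For 1-D input, each of these calls corresponds to an ordinary Python slice `input_data[begin:end:stride]`, including the negative indices and the negative stride. The comparison below uses plain lists only, so it runs without TensorFlow. It is meant to illustrate the indexing rules, not TensorFlow's implementation.

```python
input_data = [1, 2, 3, 4, 5, 6, 7, 8, 9]

# Python slice equivalents of the tf.strided_slice calls for a..e
a = input_data[0:4:1]     # begin 0, end 4 (excluded), stride 1
b = input_data[0:-1:1]    # end -1 = last element, which is excluded
c = input_data[0:-1:2]    # every second element before the last one
d = input_data[0:-1]      # stride omitted -> 1
e = input_data[-1:-2:-1]  # start at the last element and step backwards once

for name, value in zip("abcde", (a, b, c, d, e)):
    print(name, value)
```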
tf.strided_slice on higher-dimensional tensors:
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
import numpy as np

input_data = np.arange(60).reshape(3, 4, 5)
# NumPy slicing: begin:end along each axis
slice_np_data = input_data[1:2, 0:2, 0:2]
print(slice_np_data)
>>>
[[[20 21]
  [25 26]]]

# the same slice with tf.strided_slice (strides default to 1)
slice_tensor = tf.strided_slice(input_data, [1,0,0], [2,2,2])
with tf.Session() as sess:
    print(sess.run(slice_tensor))
>>>
[[[20 21]
  [25 26]]]
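The same begin/end/stride rule applies independently along every axis. As a rough sketch (assuming in-range indices and no masks), a recursive plain-Python function can reproduce the higher-dimensional example on nested lists. `strided_slice_nested` and `arange_3d` are hypothetical helpers, with `arange_3d` standing in for np.arange(60).reshape(3, 4, 5).

```python
def strided_slice_nested(data, begin, end, strides):
    """Hypothetical model of tf.strided_slice on nested lists (no masks)."""
    if not begin:  # all axes handled
        return data
    # slice the current axis, then recurse into the remaining axes
    return [strided_slice_nested(item, begin[1:], end[1:], strides[1:])
            for item in data[begin[0]:end[0]:strides[0]]]


def arange_3d(d0, d1, d2):
    """Nested-list equivalent of np.arange(d0 * d1 * d2).reshape(d0, d1, d2)."""
    return [[[i * d1 * d2 + j * d2 + k for k in range(d2)]
             for j in range(d1)]
            for i in range(d0)]


input_data = arange_3d(3, 4, 5)
print(strided_slice_nested(input_data, [1, 0, 0], [2, 2, 2], [1, 1, 1]))
# [[[20, 21], [25, 26]]]
print(strided_slice_nested(input_data, [0, 0, 0], [3, 4, 5], [2, 2, 2]))
# [[[0, 2, 4], [10, 12, 14]], [[40, 42, 44], [50, 52, 54]]]
```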
tf.strided_slice in a seq2seq model:
In a seq2seq model, the target sequences are sliced to build the decoder input:
# drop the last token of every target sequence
ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
# prepend the <s> start token to every sequence
dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['<s>']), ending], 1)
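Supposing batch_size is n, tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1]) keeps every sequence in the batch and drops the last token of each one, i.e. ending = targets[0:n, 0:-1]. tf.concat then puts the <s> token ID in front of each row, so the decoder input is the target sequence shifted right by one step.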
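The effect on a small batch can be traced in plain Python. The token IDs below are made up for illustration (assume 1 is the ID of `<s>`, 2 marks the end of a sequence and 0 is padding). Only the slicing and concatenation steps mirror the TensorFlow code.

```python
# hypothetical token IDs: 1 = '<s>' (start), 2 = end-of-sequence, 0 = padding
start_id = 1
targets = [[5, 6, 7, 2],
           [8, 9, 2, 0]]
batch_size = len(targets)

# tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1]):
# keep every row, drop the last token of each row
ending = [row[0:-1:1] for row in targets[0:batch_size:1]]

# tf.concat([tf.fill([batch_size, 1], start_id), ending], 1):
# prepend the start token to every row
dec_input = [[start_id] + row for row in ending]

print(ending)     # [[5, 6, 7], [8, 9, 2]]
print(dec_input)  # [[1, 5, 6, 7], [1, 8, 9, 2]]
```

Every row of dec_input keeps the original length of the targets, which is why the last token is dropped before the start token is added.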