tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, infer_shape=True, name=None)
具体参数定义可以参见官网,这里捡最常见的形式来说,常用函数形式为tf.scan(fn, elems, initializer=None)
函数说明:scan on the list of tensors unpacked from elems on dimension 0.
f(n)以elems的第一维度的变量list作函数计算直到遍历完整个elems
关于 initializer的说明为:
If no initializer is provided, the output structure and dtypes of fn are assumed to be the same as its input; and in this case, the first argument of fn must match the structure of elems.
If an initializer is provided, then the output of fn must have the same structure as initializer; and the first argument of fn must match this structure.
也就是说当initializer给定的时候,fn的输出结构必须和initializer保持一致,且fn的第一个参变量也必须和该结构一致。而如果该参数没有给定的时候初始化默认和x[0]的维度保持一致。
设函数为f,
x = [u(0),u(1),...,u(n)]
y = tf.scan(f,x,initializer=v(0))
此时f的参数类型必须是(v(0),x),f的输出必须和v(0)保持一致,整个计算过程如下:
v(1)=f(v(0),u(0))
v(2)=f(v(1),u(1))
....
v(n+1)=f(v(n),u(n))
y=v(n+1)
网友评论