美文网首页
tf.py_func灵活操作tensor

tf.py_func灵活操作tensor

作者: 枫丫头爱学习 | 来源:发表于2019-03-02 20:33 被阅读0次

tensorflow中所有的tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的,不能对其进行判值操作,如if ... else...等,在实际问题中,我们可能需要将一个tensor转换成numpy array 然后进行一些 np.的运算,然后返回tensor. 这样可以加强tensorflow的灵活性。

tf.py_func


其中func函数以numpy arrays 作为输入(或placeholder 需要feed),并以numpy arrays 作为输出。在函数中可自由使用针对 numpy arrays 的操作。

解释一下参数:
func: 是用户自定义函数,输入是numpy array 输出是numpy array
inp: 是func函数接受的输入,是一个列表
Tout: 指定numpy转化为tensor 后的形式

tf.py_func 返回值是一个tensor

注意:

tf.py_func中的func是脱离Graph的。在func中不能定义可训练的参数参与网络训练(反向传播)。

举个例子:

import tensorflow as tf

def add(x,y):
     return x+y,x-y,x.dot(y)

a = [[1,2],[3,4]]
b = [[1,2],[1,1]]
x = tf.placeholder(tf.float32,(2,2))
y = tf.placeholder(tf.float32,(2,2))
result1,result2,result3 = tf.py_func(add, [x,y], [tf.float32,tf.float32,tf.float32])

with tf.Session as sess:
    sess.run(tf.global_varbles_initializer())
    s1,s2,s3 = sess.run([result1,result2,result3],feed_dict = {x:a,y:b})
    print(s1,s2,s3)

相关文章

网友评论

      本文标题:tf.py_func灵活操作tensor

      本文链接:https://www.haomeiwen.com/subject/rieyuqtx.html