TensorFlowOnSpark: A Basic Example

Author: 枫隐_5f5f | Published 2019-06-11 21:56

TensorFlowOnSpark makes it easy to scale an existing TensorFlow application out to a Spark cluster for distributed training.

一 The main (driver) program

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--rdma", help="use RDMA connection", default=False)
    args = parser.parse_args()

    conf = SparkConf().setAppName("ceshi")
    sc = SparkContext(conf=conf)
    num_executors = 4
    # In practice, read the executor count from the Spark conf instead of hard-coding it:
    # executors = sc._conf.get("spark.executor.instances")
    # num_executors = int(executors) if executors is not None else 1
    tensorboard = False
    num_ps = 1  # number of parameter server tasks
    cluster = TFCluster.run(sc, main_func, args, num_executors, num_ps, tensorboard, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()

You can add further parser.add_argument() calls as your application requires, but --rdma is mandatory, because main_func passes args.rdma to ctx.start_cluster_server().
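As a minimal sketch of adding extra flags (the --epochs and --batch_size flags here are hypothetical, not part of the original; only --rdma is required, and action="store_true" is a slightly more idiomatic way to declare it than the original's default=False):

```python
import argparse

parser = argparse.ArgumentParser()
# --rdma must stay: main_func passes args.rdma to ctx.start_cluster_server().
# action="store_true" makes it a proper boolean flag.
parser.add_argument("--rdma", help="use RDMA connection", action="store_true")
# Hypothetical extra flags, added the same way:
parser.add_argument("--epochs", help="number of training steps", type=int, default=2000)
parser.add_argument("--batch_size", help="mini-batch size", type=int, default=100)

# e.g. spark-submit ... main.py --epochs 500
args = parser.parse_args(["--epochs", "500"])
print(args.rdma, args.epochs, args.batch_size)  # → False 500 100
```

Whatever flags you define are available inside main_func via the args object that TFCluster.run forwards to it.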

二 Wrap the entire TensorFlow application in main_func(args, ctx)

and add the following at the top of that function:

cluster, server = ctx.start_cluster_server(1, args.rdma)
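After start_cluster_server, a main_func typically branches on the task's role: parameter-server tasks block in server.join() while workers build and train the graph. The complete example in section 三 skips this branch and trains on every task; the sketch below shows the usual pattern with a stub ctx (a SimpleNamespace standing in for the real TFNodeContext that TFCluster.run supplies):

```python
from types import SimpleNamespace

def role_of(ctx):
    """Sketch of the branch a main_func usually contains after
    ctx.start_cluster_server(): ps tasks block serving variables,
    worker tasks build the graph and train."""
    if ctx.job_name == "ps":
        return "server.join() (serve variables)"
    return "build graph and train"

# Stub contexts; the real ctx also carries task_index, cluster spec, etc.
print(role_of(SimpleNamespace(job_name="ps", task_index=0)))      # → server.join() (serve variables)
print(role_of(SimpleNamespace(job_name="worker", task_index=1)))  # → build graph and train
```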

三 Complete code

from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster, TFNode


def main_func(args, ctx):
    # Imports live inside main_func because this function is shipped to
    # and executed on the Spark executors.
    import numpy as np
    import tensorflow as tf

    cluster, server = ctx.start_cluster_server(1, args.rdma)

    def add_layer(inputs, insize, outsize, activation_func=None):
        Weights = tf.Variable(tf.random_normal([insize, outsize]))
        bias = tf.Variable(tf.zeros([1, outsize]) + 0.1)
        wx_plus_b = tf.matmul(inputs, Weights) + bias
        if activation_func:
            return activation_func(wx_plus_b)
        else:
            return wx_plus_b

    # Toy regression data: y = x^2 plus Gaussian noise
    x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape)
    y_data = np.square(x_data) + noise

    xs = tf.placeholder(tf.float32, [None, 1])
    ys = tf.placeholder(tf.float32, [None, 1])

    l1 = add_layer(xs, 1, 10, activation_func=tf.nn.relu)
    preds = add_layer(l1, 10, 1, activation_func=None)

    loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - preds), reduction_indices=[1]))

    train = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(2000):
            sess.run(train, feed_dict={xs: x_data, ys: y_data})
            # Print predictions periodically (the original's `i % 2000`
            # only fired at step 0).
            if i % 200 == 0:
                preds_val = sess.run(preds, feed_dict={xs: x_data, ys: y_data})
                print(preds_val)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--rdma", help="use RDMA connection", default=False)
    args = parser.parse_args()

    conf = SparkConf().setAppName("ceshi")
    sc = SparkContext(conf=conf)
    num_executors = 4
    # executors = sc._conf.get("spark.executor.instances")
    # num_executors = int(executors) if executors is not None else 1
    tensorboard = False
    num_ps = 1

    cluster = TFCluster.run(sc, main_func, args, num_executors, num_ps, tensorboard, TFCluster.InputMode.TENSORFLOW)

    cluster.shutdown()

四 Submit to the Spark cluster

spark-submit \
--master yarn \
--deploy-mode cluster \
--num-executors 4 \
--executor-memory 1G \
--py-files tensor_test.py \
main.py

Original article: https://www.haomeiwen.com/subject/wqocfctx.html