A quick way to scale an existing TensorFlow application out to a Spark cluster for distributed training with TensorFlowOnSpark.
1. The driver program
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from tensorflowonspark import TFCluster
# main_func(args, ctx) holds the TensorFlow program; see sections 2 and 3

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--rdma", help="use RDMA for tensor transport",
                        action="store_true")  # a proper flag; a bare default=False would require a value after --rdma
    args = parser.parse_args()

    conf = SparkConf().setAppName("ceshi")
    sc = SparkContext(conf=conf)

    num_executors = 4   # must match --num-executors passed to spark-submit
    # or read it from the Spark config instead of hard-coding:
    # executors = sc._conf.get("spark.executor.instances")
    # num_executors = int(executors) if executors is not None else 1
    num_ps = 1          # parameter-server tasks; the remaining executors become workers
    tensorboard = False # set True to start TensorBoard on one executor

    cluster = TFCluster.run(sc, main_func, args, num_executors, num_ps,
                            tensorboard, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()  # blocks until all TensorFlow nodes have finished
Further parser.add_argument() calls can be added to suit your program, but --rdma is required: start_cluster_server reads args.rdma inside main_func. A sketch of how extra options flow through follows.
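For illustration, a minimal sketch: the option names --steps and --lr below are hypothetical and not part of TensorFlowOnSpark; only args.rdma is actually consumed by start_cluster_server.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--rdma", help="use RDMA for tensor transport", action="store_true")
parser.add_argument("--steps", help="training steps (hypothetical example)", type=int, default=2000)
parser.add_argument("--lr", help="learning rate (hypothetical example)", type=float, default=0.05)
args = parser.parse_args()
# TFCluster.run forwards args verbatim, so main_func(args, ctx) can read
# args.steps, args.lr, etc. on every executor.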
2. Package the entire TensorFlow program into main_func(args, ctx)
and add this line at its top:

cluster, server = ctx.start_cluster_server(1, args.rdma)

start_cluster_server launches this executor's tf.train.Server; cluster is the tf.train.ClusterSpec covering all executors and server is the local in-process server.
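Every executor runs main_func once, either as the parameter server or as a worker, so main_func usually branches on ctx.job_name. A minimal sketch of that standard pattern (the one-variable toy model is only a placeholder, not part of this article's example):

def main_func(args, ctx):
    import tensorflow as tf
    cluster, server = ctx.start_cluster_server(1, args.rdma)

    if ctx.job_name == "ps":
        server.join()  # PS tasks host variables and block until shutdown
    elif ctx.job_name == "worker":
        # pin variables to the PS tasks, ops to this worker
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % ctx.task_index,
                cluster=cluster)):
            w = tf.Variable(0.0)  # placeholder model: minimize w^2
            train_op = tf.train.GradientDescentOptimizer(0.1).minimize(tf.square(w))
            init = tf.global_variables_initializer()
        # attach the session to the in-process server, not a purely local one
        with tf.Session(server.target) as sess:
            sess.run(init)
            for _ in range(100):
                sess.run(train_op)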
3. Complete code
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from tensorflowonspark import TFCluster, TFNode

def main_func(args, ctx):
    import numpy as np
    import tensorflow as tf

    cluster, server = ctx.start_cluster_server(1, args.rdma)

    if ctx.job_name == "ps":
        server.join()  # parameter server: host variables, block until shutdown
        return

    # worker task from here on
    def add_layer(inputs, insize, outsize, activation_func=None):
        Weights = tf.Variable(tf.random_normal([insize, outsize]))
        bias = tf.Variable(tf.zeros([1, outsize]) + 0.1)
        wx_plus_b = tf.matmul(inputs, Weights) + bias
        if activation_func:
            return activation_func(wx_plus_b)
        else:
            return wx_plus_b

    # toy regression data: y = x^2 plus Gaussian noise
    x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
    noise = np.random.normal(0, 0.05, x_data.shape)
    y_data = np.square(x_data) + noise

    # place variables on the PS tasks, computation on this worker
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % ctx.task_index,
            cluster=cluster)):
        xs = tf.placeholder(tf.float32, [None, 1])
        ys = tf.placeholder(tf.float32, [None, 1])
        l1 = add_layer(xs, 1, 10, activation_func=tf.nn.relu)
        preds = add_layer(l1, 10, 1, activation_func=None)
        loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - preds),
                                            reduction_indices=[1]))
        train = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
        init = tf.global_variables_initializer()

    # connect to the in-process gRPC server rather than a purely local session
    with tf.Session(server.target) as sess:
        sess.run(init)
        for i in range(2000):
            sess.run(train, feed_dict={xs: x_data, ys: y_data})
            if i % 200 == 0:  # the original "i % 2000" only fired at i == 0
                preds_val = sess.run(preds, feed_dict={xs: x_data, ys: y_data})
                print(preds_val)
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--rdma", help="use RDMA for tensor transport",
                        action="store_true")
    args = parser.parse_args()

    conf = SparkConf().setAppName("ceshi")
    sc = SparkContext(conf=conf)

    num_executors = 4  # must match --num-executors passed to spark-submit
    # or read it from the Spark config instead of hard-coding:
    # executors = sc._conf.get("spark.executor.instances")
    # num_executors = int(executors) if executors is not None else 1
    num_ps = 1
    tensorboard = False

    cluster = TFCluster.run(sc, main_func, args, num_executors, num_ps,
                            tensorboard, TFCluster.InputMode.TENSORFLOW)
    cluster.shutdown()
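Note on sizing: TFCluster.run launches one TensorFlow task per Spark executor, so num_executors = 4 with num_ps = 1 gives one PS plus three workers, and the --num-executors value in the submit command below must match.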
4. Submitting to the Spark cluster
spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --num-executors 4 \
  --executor-memory 1G \
  --py-files tensor_test.py \
  main.py

Here --num-executors must equal num_executors in the driver code. --py-files is only needed when main_func lives in a separate module (tensor_test.py here); if everything sits in one file, submit that file alone.
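A slightly more defensive variant, assuming a YARN cluster with TensorFlow already installed on every node. The two --conf settings follow TensorFlowOnSpark's YARN guidance: the TF cluster needs a fixed executor count, so dynamic allocation is disabled, and restarting a half-dead cluster is avoided by capping app attempts. The trailing --rdma is forwarded to the script's argparse.

spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --num-executors 4 \
  --executor-memory 1G \
  --conf spark.dynamicAllocation.enabled=false \
  --conf spark.yarn.maxAppAttempts=1 \
  --py-files tensor_test.py \
  main.py --rdma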