美文网首页
2024-01-23-pyspark调用tf2.0模型进行分布式

2024-01-23-pyspark调用tf2.0模型进行分布式

作者: 破阵子沙场秋点兵 | 来源:发表于2024-01-22 10:57 被阅读0次

代码来自-github

代码

Load Dependencies

import tensorflow as tf
from pyspark import SparkFiles
from pyspark.sql.functions import udf
import pyspark.sql.types as T
from pyspark.sql import Row
print(tf.__version__)

Fetch SavedModel from S3/GCS and Distribute to Nodes

# URI scheme used to address the model in object storage
S3_PREFIX = "s3://"

# Bucket, key prefix and directory name of the model (placeholder values)
MODEL_BUCKET = "my-models-bucket"
MODEL_PATH = "path/to/my/model/dir"
MODEL_NAME = "model"

# Full URI of the model, assembled from the parts above
S3_MODEL = f"{S3_PREFIX}{MODEL_BUCKET}/{MODEL_PATH}/{MODEL_NAME}"

print("Fetching model", S3_MODEL)

# Add model to all workers
# recursive=True is passed because the model path is presumably a directory;
# `spark` is assumed to be an already-created SparkSession (not defined here)
spark.sparkContext.addFile(S3_MODEL, recursive=True)

Create the Input Dataframe

# In this example, the SavedModel has the following format:

# inputs = tf.keras.Input(shape=(784,), name='img')
# x = layers.Dense(64, activation='relu')(inputs)
# x = layers.Dense(64, activation='relu')(x)
# outputs = layers.Dense(10, activation='softmax')(x)
# model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

# Only the test images are used; the training split and all labels are discarded.
(_, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
# Flatten each image to 784 values and scale the pixels by 1/255.
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# One Row per image; the image is wrapped in an outer list so that the
# column value is a nested array (a batch containing a single image).
rows = [Row(img=[pixels.tolist()]) for pixels in x_test]

# Single column 'img' of type array<array<float>>.
schema = T.StructType([
    T.StructField('img', T.ArrayType(T.ArrayType(T.FloatType()))),
])

input_df = spark.createDataFrame(rows, schema=schema)

Memoize Retrieval of the Saved Model

# Simple memoization helper that caches one result per distinct argument
def compute_once(f):
    """Wrap a one-argument function so it runs at most once per argument.

    The first call with a given argument invokes ``f`` and stores the result;
    later calls with an equal argument return the stored result without
    calling ``f`` again.  Results are kept for the lifetime of the wrapper.

    Args:
        f: Callable taking a single positional argument.

    Returns:
        A callable with the same single-argument signature as ``f``.

    Note:
        The argument is used as a dictionary key, so it must be hashable.
    """
    cache = {}

    def wrapper(x):
        # Key on the argument itself: a single shared key would hand back the
        # result of the very first call even when a different argument
        # (e.g. another model name) is requested later.
        if x not in cache:
            cache[x] = f(x)

        return cache[x]

    return wrapper
    

def load_model(model_name):
    """Load the SavedModel called ``model_name`` from the SparkFiles root.

    Args:
        model_name: Name of the model directory under the SparkFiles root
            directory.

    Returns:
        The object returned by ``tf.saved_model.load`` for the 'serve' tag.
    """
    # Files are looked up beneath the SparkFiles root directory
    root_dir = SparkFiles.getRootDirectory()
    return tf.saved_model.load(f"{root_dir}/{model_name}", tags=['serve'])
    

# Only load the model once per worker!
# The reduced disk IO makes prediction much faster
# compute_once wraps load_model so that repeated calls reuse a cached result
memo_model_load = compute_once(load_model)

def get_model_prediction(model_name, input):
    """Run the default serving signature of a saved model on ``input``.

    Args:
        model_name: Name of the model, passed to ``memo_model_load``.
        input: Value passed straight to the serving signature
            (the caller in this file passes a ``tf.constant``).

    Returns:
        Whatever the serving signature returns.

    Note (from the original author):
        The TF session is scoped to where the model is loaded, so every call
        to the model's ConcreteFunction must happen in the same scope as the
        loaded model (i.e. in the same function). Otherwise TF raises errors
        about undefined variables.
    """
    # Obtain the model (loaded from disk, or returned from the cache)
    model = memo_model_load(model_name)

    # Look up the default serving signature and call it in this same scope
    serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    return model.signatures[serving_key](input)

Create the Predict UDF

# Decorator with return type of UDF
@udf("array<array<float>>")
def infer(data):
    """Spark UDF: run the model on one column value and return nested lists.

    Args:
        data: Column value to feed to the model; it is converted to a
            tensor with ``tf.constant``.

    Returns:
        The first output tensor of the model, converted to plain Python
        lists via ``.numpy().tolist()``.
    """
    # The prediction is a dict of the form { TENSOR_NAME: Tensor }
    outputs = get_model_prediction(MODEL_NAME, tf.constant(data))

    # A single output is assumed, so only the first value is taken
    first_output = list(outputs.values())[0]

    # Back to regular Python objects so Spark can serialize the result
    return first_output.numpy().tolist()

Infer on the Dataset 🎉

Infer on the Dataset 🎉

## NOTE: using mapPartitions is actually the recommended approach here; it will be faster
predictions_df = input_df.withColumn("predictions", infer("img"))

# All done :)
predictions_df.show(vertical=True)

相关文章

网友评论

      本文标题:2024-01-23-pyspark调用tf2.0模型进行分布式

      本文链接:https://www.haomeiwen.com/subject/xvyyodtx.html