美文网首页机器学习
load_model 如何导入自定义的loss 函数

load_model 如何导入自定义的loss 函数

作者: 泡泡_e661 | 来源:发表于2018-10-31 01:45 被阅读0次

    训练一个lstm模型,然后保存为model.h5文件,之后load_model("model.h5") 出错,错误如下

    ValueError: Unknown loss function:root_mean_squared_error

    原因:训练模型时的loss函数是自己定义的RMSE函数,如下:

    def root_mean_squared_error(y_true, y_pred):

            return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))

    在此更正一下,RMSE 函数定义应该没有axis=-1,上面那个函数应该是定义的MAE

    def root_mean_squared_error(y_true, y_pred):

            return K.sqrt(K.mean(K.square(y_pred - y_true)))

    模型编译如下:

    model.compile(optimizer = "rmsprop", loss = root_mean_squared_error, metrics =["accuracy"])

    经过网上查找,找到一个快速并且有效的解决办法,在这里和大家分享,希望可以帮助小伙伴,解决同样的issue

    需要再将root_mean_squared_error定义一遍,就是再写一遍(如何在你的script中已经存在root_mean_squared_error函数,就不需要重新定义了。我是写了两个scripts,一个用于模型训练,一个用于模型应用new data进行regression)

            def root_mean_squared_error(y_true, y_pred):

                    return K.sqrt(K.mean(K.square(y_pred - y_true)))

    然后在load_model中加入一个参数custom_objects如下:

    model = load_model('model.h5', custom_objects={'root_mean_squared_error': root_mean_squared_error})

    如果解决了您的问题,给个赞👍吧,谢谢!!!

    相关文章

      网友评论

        本文标题:load_model 如何导入自定义的loss 函数

        本文链接:https://www.haomeiwen.com/subject/kzhotqtx.html