美文网首页
djl入门-如何在你的Java App中优雅的调用Pytorch

djl入门-如何在你的Java App中优雅的调用Pytorch

作者: 郭彦超 | 来源:发表于2020-06-09 08:58 被阅读0次

pytorch是主流的深度学习框架,不论是学术界还是工业界已有很多成熟的模型可以使用,苦无自己技术语言的壁垒,无法将他们很好的应用在自己的项目当中

Djl介绍

Djl是基于MxNet、Pytorch、TensorFlow作为backend的api框架,它屏蔽了不同模型的调用差异,用户无需了解底层框架的使用便能很快的开发出属于自己的深度学习模型

目前djl支持图像分类,图像检测,姿态预估,语义分割,Nlp模型等

将模型导入到pytorch

准备自己的模型,模型以pt结尾

1,并拷贝模型的位置替换下面代码中的模型路径

2,设置模型名字,对应模型外层文件夹名称

3,输入图片的大小要调整为模型支持的大小

4,最后一步是设置模型的分类映射文件,这个很重要,配置错误会导致预测异常或不准


public static void main(String[] args) throws Exception {

        Path modelDir = Paths.get("/Users/gxd/.djl.ai/cache/repo/model/cv/image_classification/ai/djl/pytorch/resnet/50/imagenet/0.0.1");

        Model model = Model.newInstance(Device.defaultDevice(),"PyTorch");  //MXNet

        model.load(modelDir, "traced_resnet50");



        Pipeline pipeline = new Pipeline();

        pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());

        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()

                .setPipeline(pipeline)

                .setSynsetArtifactName("synset.txt")

                .optApplySoftmax(true)

                .build();



        BufferedImage img = BufferedImageUtils.fromUrl("http://bigdata.res.yqxiu.cn/banff-4380804_960_720.jpg");

        Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator);

        long s1 = System.currentTimeMillis();

        Classifications classifications = predictor.predict(img);

        System.out.printf(classifications.best().toString()+"===="+(System.currentTimeMillis()-s1));

    }



相关文章

网友评论

      本文标题:djl入门-如何在你的Java App中优雅的调用Pytorch

      本文链接:https://www.haomeiwen.com/subject/begjtktx.html