pytorch是主流的深度学习框架,不论是学术界还是工业界已有很多成熟的模型可以使用,苦无自己技术语言的壁垒,无法将他们很好的应用在自己的项目当中
Djl介绍
Djl是基于MxNet、Pytorch、TensorFlow作为backend的api框架,它屏蔽了不同模型的调用差异,用户无需了解底层框架的使用便能很快的开发出属于自己的深度学习模型
目前djl支持图像分类,图像检测,姿态预估,语义分割,Nlp模型等
将模型导入到pytorch
准备自己的模型,模型以pt结尾
1,并拷贝模型的位置替换下面代码中的模型路径
2,设置模型名字,对应模型外层文件夹名称
3,输入图片的大小要调整为模型支持的大小
4,最后一步是设置模型的分类映射文件,这个很重要,配置错误会导致预测异常或不准
public static void main(String[] args) throws Exception {
Path modelDir = Paths.get("/Users/gxd/.djl.ai/cache/repo/model/cv/image_classification/ai/djl/pytorch/resnet/50/imagenet/0.0.1");
Model model = Model.newInstance(Device.defaultDevice(),"PyTorch"); //MXNet
model.load(modelDir, "traced_resnet50");
Pipeline pipeline = new Pipeline();
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
.setPipeline(pipeline)
.setSynsetArtifactName("synset.txt")
.optApplySoftmax(true)
.build();
BufferedImage img = BufferedImageUtils.fromUrl("http://bigdata.res.yqxiu.cn/banff-4380804_960_720.jpg");
Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator);
long s1 = System.currentTimeMillis();
Classifications classifications = predictor.predict(img);
System.out.printf(classifications.best().toString()+"===="+(System.currentTimeMillis()-s1));
}
网友评论