Getting Started with DJL: Hands-On Transfer Learning for Car Recognition (Source Code Included)


Author: 郭彦超 | Published: 2020-06-11 23:17

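Transfer learning in one sentence: it lets you do image classification even with a small number of samples, standing on the shoulders of giants to build what you actually need.
In most cases we use deep learning models to solve practical problems. Here, for example, we want an image classification model that recognizes cars. Training such a model from scratch means labeling a large number of samples and then tuning the model until it performs well. Is there a simpler way to quickly apply an existing model to our own business data for prediction?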

Introduction

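We will build an image classifier using transfer learning, a popular deep learning technique that quickly adapts an existing high-accuracy model to a different business scenario. Compared with training a model from scratch, it gets you to a robust, accurate model much faster.
Next we use ResNet18 for transfer learning to predict 10 car makes. ResNet18 is a very capable model: an 18-layer neural network trained on about 1.2 million images from the ImageNet dataset, able to recognize 1,000 classes.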

Data Preparation

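We use Jsoup to crawl car images from xcar.com.cn (爱卡汽车) and store them under /data/cars/<car name>, one directory per class. The following 10 makes are crawled (Mercedes-Benz, BMW, Audi, Buick, Nissan, Volkswagen, Ford, Hongqi, Toyota, Honda):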

"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"

---data
   ---cars
       ---奔驰
            ---bc1.jpg
            ---bc2.jpg
            ....
       ---宝马
            ---bmw1.jpg
            ---bmw2.jpg
            ....
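
The one-directory-per-class layout above is what allows directory names to double as class labels. The following is a minimal sketch, using only JDK file APIs, of how such labels could be derived. The helper name FolderLabels and the temporary demo directories are assumptions for illustration; they are not part of DJL or of the original source.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper: derives class labels from a folder-per-class layout.
class FolderLabels {

    // Returns the sorted names of the immediate subdirectories of root.
    // Each subdirectory name is treated as one class label; plain files are ignored.
    static List<String> listLabels(Path root) throws IOException {
        try (Stream<Path> children = Files.list(root)) {
            return children
                    .filter(Files::isDirectory)
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a tiny throwaway layout so the sketch runs anywhere.
        Path root = Files.createTempDirectory("cars");
        Files.createDirectories(root.resolve("bmw"));
        Files.createDirectories(root.resolve("audi"));
        List<String> labels = listLabels(root);
        for (int i = 0; i < labels.size(); i++) {
            System.out.println(i + " -> " + labels.get(i));
        }
    }
}
```

Sorting the names keeps the index assigned to each label stable between runs, which matters when the same mapping is reused at prediction time.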

Crawler code:


// Imports inferred from the class names used: Jsoup, fastjson, Hutool, HtmlCleaner.
import java.io.File;
import java.net.URL;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.htmlcleaner.HtmlCleaner;
import org.htmlcleaner.TagNode;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import cn.hutool.core.img.ImgUtil;
import cn.hutool.core.io.FileUtil;

// The class name is illustrative; the original shows only the main method.
public class CarImageCrawler {

    public static void main(String[] args) throws Exception {
        String search = "http://sou.xcar.com.cn/XcarSearch/car/find/keyword/%s/pbid/none/chexiLevel/none/priceLevel/70_/sort/down/pageNO/1/limit/50?rand=1591882856094";

        String basePath = "/data/cars/";
        String[] cars = new String[]{"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"};

        for (String car : cars) {
            Document doc = Jsoup.connect(String.format(search, car))
                    .timeout(5000).get();

            // The search endpoint returns JSONP: strip the findcar( ... ) wrapper to get the JSON list
            String bd = doc.body().text().replace("findcar(", "");
            bd = bd.substring(0, bd.length() - 1);
            JSONObject json = JSON.parseObject(bd);

            JSONArray jsonCars = json.getJSONArray("spserList");
            for (int i = 0; i < jsonCars.size(); i++) {
                JSONObject jsonCar = jsonCars.getJSONObject(i);
                String imgUrl = jsonCar.getString("purl");
                String persid = jsonCar.getString("persid");

                // Download the cover image into /data/cars/<car name>/
                File f = new File(basePath + car + imgUrl.substring(imgUrl.lastIndexOf("/")));
                FileUtil.mkParentDirs(f);
                ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);

                // Fetch the detail page to collect more exterior photos
                doc = Jsoup.connect(String.format("http://newcar.xcar.com.cn/photo/ps%s-s_1/", persid)).timeout(4000).get();
                HtmlCleaner cleaner = new HtmlCleaner();
                // Convert the page to a TagNode tree
                TagNode node = cleaner.clean(doc.html());
                // Extract the image elements with XPath
                Object[] ns2 = node.evaluateXPath("//div[@class='pic-wrap']/div[@class='pic-con']/dl/dt/a/img");
                for (Object on : ns2) {
                    TagNode n = (TagNode) on;
                    // src is protocol-relative, so prepend the scheme
                    imgUrl = "http:" + n.getAttributeByName("src");
                    f = new File(basePath + car + imgUrl.substring(imgUrl.lastIndexOf("/")));
                    ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);
                }
            }
            // Pause between makes to avoid hammering the site
            Thread.sleep(3000);
        }
    }
}
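
The crawler relies on the search endpoint returning JSONP, that is, JSON wrapped in a findcar( ... ) callback, and removes the wrapper with plain string operations. As a rough sketch of that one step in isolation, the hypothetical helper below (JsonpUnwrap is not from the original source) strips a named callback wrapper and fails loudly if the response does not have the expected shape.

```java
// Hypothetical helper: extracts the JSON payload from a JSONP response.
class JsonpUnwrap {

    // Accepts bodies such as callback({...}) or callback({...}); and returns
    // the text between the opening parenthesis and the last closing one.
    static String unwrap(String body, String callback) {
        String s = body.trim();
        String prefix = callback + "(";
        if (!s.startsWith(prefix)) {
            throw new IllegalArgumentException("Not a JSONP response for callback " + callback);
        }
        int end = s.lastIndexOf(')');
        if (end < prefix.length()) {
            throw new IllegalArgumentException("Missing closing parenthesis");
        }
        return s.substring(prefix.length(), end);
    }

    public static void main(String[] args) {
        // The payload here is made up for the demo; it is not real site output.
        String body = "findcar({\"spserList\":[]})";
        System.out.println(unwrap(body, "findcar"));
    }
}
```

Compared with a blanket replace, anchoring on the prefix and the last closing parenthesis avoids damaging payloads that happen to contain the callback name or parentheses inside string values.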

Dataset: https://pan.baidu.com/s/16AHf-zJzcjWuGwokECKXYg (extraction code: tqvv)

Downloading the ResNet Model

Download resnet18 from Baidu Netdisk and extract the model to /data/models/resnet


Scan the QR code to download resnet18

Rebuilding the Model

1. Load the existing model
2. Remove the original classification output layer and add a new one

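import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;

// Load the pretrained resnet18
private Model getModel() throws IOException, MalformedModelException {
        Path modelDir = Paths.get("/data/models/resnet");
        Model model = Model.newInstance(Device.cpu(), "MXNet");
        model.load(modelDir, "resnet18_v1");
        return model;
}

// Remove resnet18's fully connected layer and add a new classification output layer for our own classes
private void prepareModel(Model old) {
        SequentialBlock newBlock = new SequentialBlock();
        SymbolBlock block = (SymbolBlock) old.getBlock();
        // Drop the original 1000-class output layer
        block.removeLastBlock();
        newBlock.add(block);
        // Squeeze out the singleton dimensions left by pooling
        newBlock.add(x -> new NDList(x.singletonOrThrow().squeeze()));
        // New output layer: one unit per car make
        newBlock.add(Linear.builder().setOutChannels(10).build());
        newBlock.add(Blocks.batchFlattenBlock());
        old.setBlock(newBlock);
}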
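
To make the role of the new output layer concrete: after the original classifier is removed, ResNet18 yields a 512-dimensional feature vector per image, and the new linear layer maps it to 10 scores, one per car make. The sketch below illustrates only that arithmetic in plain Java. LinearHead is a hypothetical name, the weights are random placeholders, and this is not how DJL implements Linear internally.

```java
import java.util.Random;

// Hypothetical illustration of a linear classification head: logits = W * x + b.
class LinearHead {

    // weights has shape [numClasses][numFeatures], bias has shape [numClasses].
    static float[] forward(float[][] weights, float[] bias, float[] features) {
        float[] logits = new float[weights.length];
        for (int j = 0; j < weights.length; j++) {
            float sum = bias[j];
            for (int i = 0; i < features.length; i++) {
                sum += weights[j][i] * features[i];
            }
            logits[j] = sum;
        }
        return logits;
    }

    public static void main(String[] args) {
        int numFeatures = 512; // ResNet18 feature size after global pooling
        int numClasses = 10;   // one output per car make
        Random rnd = new Random(0);
        float[][] w = new float[numClasses][numFeatures];
        float[] x = new float[numFeatures];
        for (float[] row : w) {
            for (int i = 0; i < numFeatures; i++) {
                row[i] = (rnd.nextFloat() - 0.5f) * 0.01f; // placeholder weights
            }
        }
        for (int i = 0; i < numFeatures; i++) {
            x[i] = rnd.nextFloat(); // placeholder features
        }
        System.out.println("logits: " + forward(w, new float[numClasses], x).length);
    }
}
```

Only these 512 x 10 weights and 10 biases are new; everything before them starts from the pretrained values, which is why a small dataset can be enough.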

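Model Training

1. Specify the number of GPUs and configure the training setup, such as the optimizer and the evaluators
2. Load the image dataset
3. Set the number of epochs and train the model
4. Evaluate the model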

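import java.io.IOException;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;

// getImgDataSet and dataPath are not shown in this excerpt.
private void train(Model model, int epoch) throws IOException {
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
                .optDevices(Device.getDevices(1)) // Limit to one GPU; using more can actually slow down convergence
                .addTrainingListeners(TrainingListener.Defaults.logging());

        try (Trainer trainer = model.newTrainer(config)) {
            // Initialize the parameters of the new layers; the 1x3x224x224 input shape
            // is an assumption chosen to match the Resize(224) used at prediction time
            trainer.initialize(new Shape(1, 3, 224, 224));

            for (int i = 0; i < epoch; ++i) {
                for (Batch batch : trainer.iterateDataset(getImgDataSet("train", dataPath))) {
                    trainer.trainBatch(batch);
                    trainer.step();
                    batch.close();
                }
                // reset training and validation evaluators at end of epoch
                trainer.endEpoch();
            }
        }
}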
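
The training configuration pairs a softmax cross-entropy loss with an accuracy evaluator. As a rough illustration of what those two quantities measure, here is a plain-Java sketch of the per-sample loss and of batch accuracy. LossAndAccuracy is a hypothetical name and the numbers in main are made up; DJL computes these on NDArrays, and this sketch shows only the underlying formulas.

```java
// Hypothetical illustration of softmax cross-entropy and top-1 accuracy.
class LossAndAccuracy {

    // Per-sample loss: -log(softmax(logits)[label]), computed in a numerically
    // stable way by subtracting the maximum logit before exponentiating.
    static double softmaxCrossEntropy(double[] logits, int label) {
        double max = logits[0];
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sumExp = 0.0;
        for (double v : logits) {
            sumExp += Math.exp(v - max);
        }
        return Math.log(sumExp) + max - logits[label];
    }

    // Index of the largest score, i.e. the predicted class.
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of samples whose predicted class equals the true label.
    static double accuracy(double[][] batchLogits, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(batchLogits[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 1.0}, {0.0, 3.0}, {5.0, 4.0}};
        int[] labels = {0, 1, 1};
        System.out.println("loss[0] = " + softmaxCrossEntropy(logits[0], labels[0]));
        System.out.println("accuracy = " + accuracy(logits, labels));
    }
}
```

The loss drives the parameter updates, while accuracy is only reported, which is why it is registered as an evaluator rather than as the loss.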

As the figure below shows, with transfer learning the accuracy is already close to 70% after only two epochs of training.


image.png

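Model Testing

1. Set up the Translator so that prediction data is processed the same way as training data, and map label ids to class names
2. Load the model we just trained into the application
3. Build a predictor, run it on a single image or a batch of images, and output the classification result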

public static String predict(String imagePath) throws IOException, MalformedModelException, TranslateException {
        BufferedImage image;
        if (imagePath.startsWith("http")) {
            image = BufferedImageUtils.fromUrl(new URL(imagePath));
        } else {
            image = BufferedImageUtils.fromFile(Paths.get(imagePath));
        }

        // Preprocess the image the same way as during training
        Pipeline pipeline = new Pipeline()
                .add(new CenterCrop())
                .add(new Resize(224))
                .add(new ToTensor())
                .add(new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));
        // The translator applies the pipeline and maps label ids to the class names in synset.txt
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .setSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();

        // modelPath and modelName are fields not shown in this excerpt
        Path modelDir = Paths.get(modelPath);
        try (Model model = Model.newInstance(Device.cpu(), "MXNet")) {
            model.load(modelDir, modelName);
            try (Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator)) {
                Classifications classifications = predictor.predict(image);
                System.out.println(classifications);
                // Return the result as text so the web endpoint can embed it in the page
                return classifications.toString();
            }
        }
}
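
The ToTensor and Normalize steps in the pipeline come down to simple per-channel arithmetic: scale each 8-bit value to the range [0, 1], then subtract the channel mean and divide by the channel standard deviation. A minimal plain-Java sketch of that arithmetic for a single RGB pixel follows, reusing the mean and standard deviation constants from the pipeline above. PixelNormalize is a hypothetical name, and the real transforms operate on whole image tensors rather than single pixels.

```java
// Hypothetical illustration of the ToTensor + Normalize arithmetic for one pixel.
class PixelNormalize {

    // Per-channel constants copied from the Normalize step in the pipeline.
    static final float[] MEAN = {0.4914f, 0.4822f, 0.4465f};
    static final float[] STD = {0.2023f, 0.1994f, 0.2010f};

    // Takes 8-bit R, G, B values (0..255) and returns the normalized floats.
    static float[] normalize(int r, int g, int b) {
        int[] rgb = {r, g, b};
        float[] out = new float[3];
        for (int c = 0; c < 3; c++) {
            float scaled = rgb[c] / 255f;          // ToTensor: 0..255 -> 0..1
            out[c] = (scaled - MEAN[c]) / STD[c];  // Normalize: zero-center and rescale
        }
        return out;
    }

    public static void main(String[] args) {
        float[] white = normalize(255, 255, 255);
        float[] black = normalize(0, 0, 0);
        System.out.println("white R = " + white[0] + ", black R = " + black[0]);
    }
}
```

Whatever constants are used, the key point from step 1 is that training and prediction must apply the same ones, since the network only ever sees the normalized values.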

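Online Deployment

Here I use the Spark web framework to quickly build an image prediction API, so you can see for yourself how well transfer learning generalizes to unseen images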

   public static void main(String[] args) {

        String repTemp = "<!DOCTYPE html><html lang=\"en\"><head>    <meta charset=\"utf-8\">    <style type=\"text/css\">        .content {            color: #ffffff;            font-size: 40px;        }        .bg {            background: url('${img}');            background-repeat: no-repeat;            background-position: center;            background-size: cover;            height:600px;            text-align: center;            line-height: 600px;        }    </style></head><body><div class=\"bg\">    <div class=\"content\">${txt}</div></div></body></html>";

        port(8899);
         
        get("/img_classes/cars/predict", (request, response) -> {
           
            return repTemp.replace("${img}",request.queryParams("img_url")).replace("${txt}",TransferLearning.predict(request.queryParams("img_url")));
            
        });
    }

Results

image.png

Who knows what car this is? Download the model and try it yourself.

image.png

http://localhost:8899/img_classes/cars/predict?img_url=https://car3.autoimg.cn/cardfs/product/g24/M06/89/22/1024x0_1_q95_autohomecar__ChwFjl6zaqOAEtu7AAa92Uw57ys354.jpg
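
The img_url query parameter in the request above carries a complete URL. If the image URL itself contains characters such as ? or &, it needs to be percent-encoded before being appended, otherwise the server may see a truncated value. A small sketch of building the request URL with the JDK's URLEncoder follows; PredictUrl is a hypothetical helper, and the image URL in main is a placeholder.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

// Hypothetical helper: builds a request URL for the prediction endpoint.
class PredictUrl {

    // Appends the image URL as a percent-encoded img_url query parameter.
    static String build(String endpoint, String imageUrl) {
        return endpoint + "?img_url=" + URLEncoder.encode(imageUrl, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String endpoint = "http://localhost:8899/img_classes/cars/predict";
        // Placeholder image URL containing its own query string
        System.out.println(build(endpoint, "https://example.com/car.jpg?w=1024&q=95"));
    }
}
```

The framework decodes query parameters on the server side, so the handler still receives the original image URL.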
