
Image Search in Two Steps (Deep-Learning Image-to-Image Search in Practice)

Author: Guo Yanchao | Published 2021-04-30 13:36

    The [易企秀] template marketplace lists a large number of new template products every day. Most of them are compliant, but template plagiarism does occur. To protect authors' rights, a large manual review team has to vet submissions daily. That approach is inefficient and applies different standards depending on the reviewer, while traditional feature-extraction schemes suffer from accuracy problems, so an efficient solution for this scenario was urgently needed.

    Overview

    The following walks through implementing image-similarity search with ES. ES 7.3 added the dense_vector field type to support similarity search over feature vectors (note that the ES version current at the time of writing supports at most 2048 dimensions per vector). We will extract features with Keras and store them in this field type for later similarity search.


    Feature Extraction (Part 1)

    Deep-learning feature extraction usually relies on a CNN. What can a CNN do? A CNN is a convolutional neural network for images: it extracts features and uses them to recognize objects. A simple way to think about the process: the network measures, along many different dimensions, how close an image's content is to the features of a cat, of a dog, and so on; the closest match becomes the recognition result, i.e. whether the image shows a cat, a dog, or something else.

    What does CNN-based recognition have to do with finding similar images? We do not need the final recognition result; we need the feature vector extracted along those many dimensions. Two images with similar content are guaranteed to have feature vectors that are close to each other.
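
    As a minimal illustration of "close feature vectors", here is a plain-numpy sketch; the vectors are made up for the example (real CNN features have thousands of dimensions):

    import numpy as np

    def cosine_similarity(a, b):
        # 1.0 = identical direction, 0.0 = orthogonal
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    # toy vectors standing in for the CNN features of three images
    feat_a = np.array([0.9, 0.1, 0.8, 0.3])
    feat_b = np.array([0.8, 0.2, 0.9, 0.2])  # similar image -> similar vector
    feat_c = np.array([0.1, 0.9, 0.0, 0.7])  # dissimilar image

    print(cosine_similarity(feat_a, feat_b))  # close to 1
    print(cosine_similarity(feat_a, feat_c))  # noticeably smaller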

    Here we use VGG16, a classic CNN. Typically the output of the network's (n-1)-th or (n-2)-th layer is used as the feature; in VGG16 that is a fully connected layer with a feature length of 4096.

    from keras.applications.vgg16 import VGG16
    from keras.preprocessing import image
    from keras.applications.vgg16 import preprocess_input
    from keras.models import Model
    import numpy as np
    
    base_model = VGG16(weights='imagenet')  # you could also train this on your own dataset
    #base_model.summary()
    # inspect the VGG16 structure; we want the 4096-dim output of the fc2 layer
    model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)  # take the fully connected layer's output
    img_path = '1.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    features = model.predict(x)
    
    print(features.shape)  # (1, 4096)
    

    How to reduce the dimensionality of the fc2 features

    1. Max method: reshape the 4096-dim feature to (1024, 4) and take the maximum within each row, yielding a 1024-dim vector:

    import numpy as np
    array = features.reshape(1024, 4)
    array = np.max(array, axis=1)
    print(array.shape, list(array))
    

    2. Mean method, analogous to a pooling layer in a CNN:

    import numpy as np
    array = features.reshape(1024, 4)
    array = np.average(array, axis=1)
    print(array.shape, list(array))
    

    3. Transfer learning: the 4096-dim features from the first extraction approach exceed the maximum length that dense_vector supports, so we define custom input/output layer sizes for the model, taking a length of 1024 here. Note that the newly added layers must be trained before use (see the training sketch after the code).

    from keras.applications.vgg16 import VGG16
    from keras.preprocessing import image
    from keras.layers import Dense, Flatten, Dropout
    from keras.applications.vgg16 import preprocess_input
    from keras.models import Model
    import numpy as np
    # define the input size
    ishape = 64
    base_model =  VGG16(include_top=False, weights='imagenet', input_shape=(ishape,ishape, 3))
    # include_top=False removes VGG16's fully connected top, keeping only the convolutional base
    for layers in base_model.layers:
        layers.trainable = False  # "freeze" the pretrained weights that do not need retraining
      
    model = Flatten()(base_model.output)
    model = Dense(4096, activation='relu',name='fc1')(model)
    model = Dense(1024, activation='relu',name='fc2')(model)
    # model = Dropout(0.25)(model)
    # model = Dense(10, activation='softmax',name='prediction')(model)
    # the Dropout and softmax head above are only needed while training the new fully connected layers
    model = Model(inputs=base_model.input, outputs=model, name='vgg16_pretrain')
    #print(model.summary()) 
    img_path = 'testKoala.jpg'
    img = image.load_img(img_path, target_size=(64, 64))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    features = model.predict(x)
    
    print(features.shape)
    list(features[0])
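
    The two new Dense layers above start from random weights, so a training step is required before their output is a usable feature. Below is a minimal sketch of that step, assuming a hypothetical labeled dataset x_train/y_train with 10 classes (the class count, epochs, and data are illustrative only):

    from keras.layers import Dense, Dropout
    from keras.models import Model

    # attach a temporary classification head so the new fc layers can learn weights
    head = Dropout(0.25)(model.output)
    head = Dense(10, activation='softmax', name='prediction')(head)  # 10 classes: illustrative
    train_model = Model(inputs=model.input, outputs=head)

    train_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    train_model.fit(x_train, y_train, epochs=5, batch_size=32)  # x_train/y_train: your own dataset

    # after training, drop the head and keep the 1024-dim fc2 output as the feature extractor
    model = Model(inputs=train_model.input, outputs=train_model.get_layer('fc2').output)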
    

    Similarity Search (Part 2)

    We use ES to implement the similarity search.

    • The index mapping is as follows:
    {
      "mappings": {
        "properties": { 
              "title": {
                "type": "text"
              },
              "features": {
                "type": "dense_vector",
                "dims": 1024
              }
         } 
      }
    }
    
    • Image similarity scoring query:
    {
      "query": {
        "script_score": {
          "query": {
            "match_all": {}
          },
          "script": {
            "source": "1 / (l2norm(params.queryVector,  'features') + 1)",
            "params": {
              "queryVector": [0.1, 0.2, 0.4, 0.1, 0.0, ...]
            }
          }
        }
      }
    }
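
    Putting the two halves together, here is a sketch of indexing extracted features and running the query with the official elasticsearch Python client (the index name "images", the local host, and the 7.x body= style are assumptions):

    from elasticsearch import Elasticsearch

    es = Elasticsearch("http://localhost:9200")  # assumed local cluster

    # index one image document together with its feature vector
    es.index(index="images", body={
        "title": "1.jpg",
        "features": features[0].tolist()  # vector from the extraction step above
    })

    # score every document by inverse L2 distance to the query vector
    resp = es.search(index="images", body={
        "size": 10,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    "source": "1 / (l2norm(params.queryVector, 'features') + 1)",
                    "params": {"queryVector": features[0].tolist()}
                }
            }
        }
    })
    for hit in resp["hits"]["hits"]:
        print(hit["_score"], hit["_source"]["title"])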
    

    (Figure: VGG16 network layer diagram)


    Addendum 1

    The above is the scheme we implemented at 易企秀; below is a supplementary look at Alibaba's approach:

    At the time Alibaba Cloud was on ES 6, which does not support vectors, but the feature-extraction part is still worth learning from. The model extracts 512-dim features, a moderate dimensionality: too few dimensions may hurt accuracy, while too many make storing and computing with the vectors expensive.

    from keras.applications.vgg16 import VGG16
    from keras.preprocessing import image
    from keras.applications.vgg16 import preprocess_input
    from keras.models import Model
    import numpy as np
    import pylab
    import requests
    import io
    from PIL import Image
    from numpy import linalg as LA
    
    def get_image_feature(url):
        base_model = VGG16(weights='imagenet', input_shape=(224, 224, 3), pooling='max', include_top=False)  # note: Alibaba takes the max-pooling output (512-dim)

        # with include_top=True you could instead take a fully connected layer's output:
        #model = Model(inputs=base_model.input, outputs=base_model.get_layer("fc2").output)  # or model.layers[3].output for a specific layer
        model = base_model
        #print(model.summary())
        
        # img_path = 'jd/test.jpg'
        # img = image.load_img(img_path, target_size=(224, 224))
        # x = image.img_to_array(img)
        response = requests.get(url)
        img_bytes = io.BytesIO(response.content)
        img = Image.open(img_bytes)
        img = img.convert('RGB')
        img = img.resize((224,224), Image.NEAREST)
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = preprocess_input(x)
        features = model.predict(x)
        return features[0]
    
    features = get_image_feature("http://res2.eqh5.com/Fpc8E_za0j_13QCFh-cAjexwHTPv?imageMogr2/format/webp/quality/80/thumbnail/270x")
    
    vec = features/LA.norm(features)
    print(vec.shape,vec.tolist())
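
    Dividing by LA.norm(features) makes vec a unit vector. For unit vectors, squared Euclidean distance and the dot product are linked by ||a - b||^2 = 2 - 2*(a·b), so an l2norm-based score and a cosine/dot-product score rank results identically. A quick numpy check of that identity:

    import numpy as np

    a = np.random.rand(512); a /= np.linalg.norm(a)
    b = np.random.rand(512); b /= np.linalg.norm(b)

    # for unit vectors: squared L2 distance == 2 - 2 * dot product
    print(np.sum((a - b) ** 2))
    print(2 - 2 * np.dot(a, b))  # same value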
    

    Addendum 2

    Loading a TensorFlow Hub feature-extraction model with Keras

    import tensorflow as tf
    import tensorflow_hub as hub
      
    import numpy as np
    import requests
    import io
    from PIL import Image
    from tensorflow.keras.preprocessing import image  # needed for img_to_array below
    
    def get_image_feature(url):
    #     base_model = VGG16(weights='imagenet')  # you could also train this on your own dataset
        base_model = tf.keras.Sequential([
            hub.KerasLayer("C:/Users/bigdata/Downloads/imagenet_mobilenet_v2_100_224_feature_vector_5",
                           trainable=False) 
        ])
        base_model.build([None, 224, 224, 3])  # Batch input shape.
        print(base_model.summary())
        model = base_model
      
        response = requests.get(url)
        img_bytes = io.BytesIO(response.content)
        img = Image.open(img_bytes)
        img = img.convert('RGB')
        img = img.resize((224,224), Image.NEAREST)
        x = image.img_to_array(img)
        x = np.expand_dims(x, axis=0)
        x = x / 255.0  # TF Hub image models expect pixel values in [0, 1]; preprocess_input is not used here
        features = model.predict(x)
        return features[0]
    
    features = get_image_feature("http://res1.eqh5.com/FhtC-ClVdA6IX772G82EtUAre9IE?imageMogr2/format/webp/quality/80/thumbnail/270x")
    print(features.shape,list(features))
    

    Addendum 3

    • Deep-learning feature extraction with DJL
      Note that feature extraction with DJL may require a proxy; otherwise some model files cannot be downloaded.
      Model download
     
    package ai.djl.examples.inference.face;
    
    import ai.djl.ModelException;
    import ai.djl.engine.Engine;
    import ai.djl.inference.Predictor;
    import ai.djl.modality.cv.Image;
    import ai.djl.modality.cv.ImageFactory;
    import ai.djl.modality.cv.transform.Normalize;
    import ai.djl.modality.cv.transform.ToTensor;
    import ai.djl.modality.cv.util.NDImageUtils;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.repository.zoo.Criteria;
    import ai.djl.repository.zoo.ModelZoo;
    import ai.djl.repository.zoo.ZooModel;
    import ai.djl.training.util.ProgressBar;
    import ai.djl.translate.*;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.io.IOException;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.util.Arrays;
    
    public final class ImgFeatureExtraction {
    
        private static final Logger logger = LoggerFactory.getLogger(ImgFeatureExtraction.class);
    
        private ImgFeatureExtraction() {}
    
        public static void main(String[] args) throws IOException, ModelException, TranslateException {
            System.out.println(Engine.getInstance().getEngineName());
            if (!"TensorFlow".equals(Engine.getInstance().getEngineName())) {
                logger.info("This example only works for PyTorch.");
                return;
            }
    
            Path imageFile = Paths.get("src/test/resources/kana1.jpg");
            Image img = ImageFactory.getInstance().fromFile(imageFile);
    
            float[] feature = ImgFeatureExtraction.predict(img);
            if (feature != null) {
                System.out.println(feature.length);
                logger.info(Arrays.toString(feature));
            }
        }
    
        public static float[] predict(Image img)
                throws IOException, ModelException, TranslateException {
            img.getWrappedImage();
            Criteria<Image, float[]> criteria =
                    Criteria.builder()
                            .setTypes(Image.class, float[].class)
                            .optModelPath(Paths.get("C:\\Users\\bigdata\\.djl.ai\\cache\\repo\\model\\undefined\\ai\\djl\\localmodelzoo\\fb7a319b76d37104079cb5e425ce8002"))
    //                        .optModelUrls("https://storage.googleapis.com/tfhub-modules/tensorflow/resnet_50/feature_vector/1.tar.gz")
    //                        .optModelName("tf2-preview_mobilenet_v2_feature_vector_4") // specify model file prefix
                            .optTranslator(new FaceFeatureTranslator())
                            .optProgress(new ProgressBar())
                            .optEngine("TensorFlow") // Use PyTorch engine
                            .build();
    
            try (ZooModel<Image, float[]> model = ModelZoo.loadModel(criteria)) {
    
                Predictor<Image, float[]> predictor = model.newPredictor();
                long t1 = System.currentTimeMillis();
                float[] re = predictor.predict(img);
                long t2 = System.currentTimeMillis();
                System.out.println(t2-t1);
                return re;
            }
        }  // sample output: 0.047473587, 0.09630219, 0.13978973, 0.0, 0.32955706
    
        private static final class FaceFeatureTranslator implements Translator<Image, float[]> {
    
            FaceFeatureTranslator() {}
    
            /** {@inheritDoc} */
            @Override
            public NDList processInput(TranslatorContext ctx, Image input) {
                NDManager manager = ctx.getNDManager();
                NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
                array = NDImageUtils.resize(array, 224).div(255.0f);
                return new NDList(array);
            }
    
            /** {@inheritDoc} */
            @Override
            public float[] processOutput(TranslatorContext ctx, NDList list) {
                NDList result = new NDList();
                long numOutputs = list.singletonOrThrow().getShape().get(0);
                for (int i = 0; i < numOutputs; i++) {
                    result.add(list.singletonOrThrow().get(i));
                }
                float[][] embeddings =
                        result.stream().map(NDArray::toFloatArray).toArray(float[][]::new);
                float[] feature = new float[embeddings.length];
                for (int i = 0; i < embeddings.length; i++) {
                    feature[i] = embeddings[i][0];
                }
                return feature;
            }
    
            /** {@inheritDoc} */
            @Override
            public Batchifier getBatchifier() {
                return Batchifier.STACK;
            }
        }
    }
    
    

    Optimization

    • Testing shows that extracting features from a single image takes around 0.5 s, fluctuating as image quality varies. To keep performance predictable, compress the original image before feature extraction (see the sketch after this list).

    • Applying a filter to the original color image can improve recognition accuracy:

    
    from PIL import ImageFilter

    img = img.filter(ImageFilter.CONTOUR)  # contour filter
    # or
    img = img.filter(ImageFilter.MaxFilter)  # max filter
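
    For the compression mentioned in the first bullet, a minimal sketch using Pillow (the target size and JPEG quality are illustrative):

    import io
    from PIL import Image

    def compress_image(img, max_side=512, quality=75):
        # downscale so the longer side is at most max_side, then re-encode as JPEG
        img.thumbnail((max_side, max_side))
        buf = io.BytesIO()
        img.convert('RGB').save(buf, format='JPEG', quality=quality)
        buf.seek(0)
        return Image.open(buf)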
    
    

    Recap

    Looking back, it is easy to see that the whole pipeline has only two major parts: a CNN + a vector search engine.
