前言
在比较大规模的iOS项目开发中,会遇到这样的场景,一个新需求使用的icon可能之前有,但是想找到实在是太难了。最近在学PyTorch,于是想到是否能用PyTorch做一个本地的图片搜索功能,经过一番搜索,非常可行
项目代码
ImageSearchApp: ImageSearchApp (gitee.com)使用QT做了一个简单的UI,安装好依赖,运行image_search_app.py
即可。目前代码未作整理,仅供参考
上图搜索到的是VOC2007图库中的汽车,在实际iOS项目中我也有尝试,比如搜索叉号icon,在项目中找到了11个叉号图片。
Python依赖
- PyQt5
- pyperclip
- torch
- torchvision
- PIL
- PySide2
图片搜索原理
图片搜索主要分为以下几步
- 构建被搜索图片的特征库
- 提取输入图片的特征
- 将输入图片的特征和被搜索图片的特征进行比对,得出最相近的top K个结果
构建被搜索图片的特征库
抽取特征
直接基于PreTrain的模型进行特征抽取,我的代码中使用了resnet-18
模型avgpool
模块的特征输出,输出尺寸为512
model = models.resnet18(pretrained=True)
layer = model._modules.get('avgpool')
通过注册hook获取特征输出
image = self.normalize(self.toTensor(img)).unsqueeze(0).to(self.device)
embedding = torch.zeros(1, self.number_features, 1, 1)
def copy_data(m ,i, o): embedding.copy_(o.data)
h = self.feature_layer.register_forward_hook(copy_data)
self.model(image)
h.remove()
特征抽取的完整代码在feature_extrator.py
中
特征持久化
为了避免每次都得重新计算特征,我使用了h5py保存特征值,使用图片文件的路径md5作为主key,分别保存path
和feature
值
h5_base_key = self.md5_of_path(img_full_path)
path_data = dbfile.create_dataset(h5_base_key + '/path', (1), dtype=h5py.special_dtype(vlen=str))
path_data[:] = img_full_path
dbfile[h5_base_key + "/feature"] = feature
这块的完整代码在batch_feature_processor.py
中
提取输入图片的特征
输入图片的特征提取直接使用feature_extrator.py
即可
特征比对
比对主要使用余弦相似度
来评估图片特征向量的相似度。在二维空间,余弦相似度可以理解为两个向量的夹角,夹角为0时,相似度最高,此时余弦为1,余弦的计算公式如下
cos(angle) = dot(VecA, VecB) / (|VecA| * |VecB|)
这个公式在高维度同样适用,比如我们输出的特征向量,是512维,计算代码如下
np.inner(feature_a.T, feature_b.T) / ((np.linalg.norm(feature_a, axis=0).reshape(-1, 1)) * ((np.linalg.norm(feature_b, axis=0).reshape(-1,1)).T))
np.inner
表示内积,在高维空间,使用内积计算向量的点乘。np.linalg.norm
则是计算第二范数,对应到二维空间就是计算长度。转置T是为了让矩阵的Shape匹配。
通过比对余弦值的大小就可以得到最匹配的图片啦~
网友评论