A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks 论文读书报告
简介
本文将异常样本点的检查细分为两类: OOD样本点(out-of-distribution samples)和 对抗样本(adversarial samples)。文中提出了一种能同时检查这两类异常样本的方式,并且可以被应用到所有已经经过预训练的softmax深度神经网络中。该方法使用 GDA (Gaussian discriminant analysis)来对特征分布进行建模预估,然后利用马氏距离(Mahalanobis distance)来计算得分,距离越远说明其离正常样本分布更远。该方法能达到当前 state of art 的性能。并且相比于之前提出的方法,他们都只针对异常样本点中其中一类进行设计,而本文中提出的方法可以同时对这两类异常点都进行检测。
GDA(Gaussian discriminant analysis)
对p(x|y)建模
它要解决分类问题,此例中y只取0或1,x取连续量。
对此模型的定义:
data:image/s3,"s3://crabby-images/2fa83/2fa8371b3a616a1c112b7c86ba5678caa053aa5e" alt=""
y显然是伯努利分布,这里我们假设了x服从多维正态分布。
具体展开上面的概率分布:
data:image/s3,"s3://crabby-images/141eb/141eb57e1ed21d3b0a23249b3a03fbe27fe88c15" alt=""
我们依然用最大化似然函数的方法得到各参数的最优值:(与前面稍有不同,前面我们用p(y|x)来得出似然函数,但这里我们为了简化计算,采用p(x,y)得出,事实上,两者同时取最值。)
data:image/s3,"s3://crabby-images/bae35/bae354043d11aca8643842bd9928fef514972856" alt=""
将概率分布代入上式,最大化这个函数,即可得到各个参数u0,u1,等,也即对输入特征的分布进行了建模。
马氏距离
用来度量一个样本点P与数据分布为D的集合的距离。
假设样本点为:
data:image/s3,"s3://crabby-images/93f30/93f3043a5fb6730caafd62529d4e72f52e60380e" alt=""
数据集分布的均值为:
data:image/s3,"s3://crabby-images/ac311/ac311f1916fe1030e1b8c49b7334625cc5e49791" alt=""
协方差矩阵为S。
则这个样本点P与数据集合的马氏距离为:
data:image/s3,"s3://crabby-images/39f5f/39f5f14e58197956e7c7b485ed355eb57fbcebdc" alt=""
马氏距离也可以衡量两个来自同一分布的样本x和y的相似性:
data:image/s3,"s3://crabby-images/2972e/2972ecf28910c1c2cf31c6a5611be8b620f37d3b" alt=""
当样本集合的协方差矩阵是单位矩阵时,即样本的各个维度上的方差均为1.马氏距离就等于欧式距离相等。
具体实现
取 DNN 倒数第二层也就是 softmax层之前的输出设为 f(x)
Mahalanobis distance-based confidence score
data:image/s3,"s3://crabby-images/663e7/663e78d8983441c3c30d8c24f806abd3ea2b8310" alt=""
其中如下公式用于提供计算马氏距离参数。这等价于在最大似然估计下训练样本具有联合协方差的条件高斯分布。:
data:image/s3,"s3://crabby-images/524f7/524f7183afbd09acdfc76c3fb320a48ee8749071" alt=""
优化方法:
data:image/s3,"s3://crabby-images/abbab/abbabe06dd7d67e5548fd9c0fd4c8a3a6fab8c1c" alt=""
简单的说,该方法首先利用 GDA 生成模型代替 softmax 构建后验概率分布,然后最小化马氏距离,计算每个样本的 Mahalanobis distance-based confidence score,得分越高说明在该分布中的概率越大。
改进
文中,作者又对以上方式进行了扩展,不仅仅使用最后一层DNN输出,而是使用所有DNN层的输出,分别计算 score ,然后使用一个简单的logistic回归进行分类设置权重。伪代码如下:
data:image/s3,"s3://crabby-images/3c99e/3c99e3cef17fd6d9e6f160c287c9b77b98ce3ca2" alt=""
实验结果
下图是文中提出的检查方法和其他方法的对比结果:
data:image/s3,"s3://crabby-images/0fa4e/0fa4ef1e2eab0ec45e7fe3c742f018146e29865f" alt=""
下图是在两类异常对抗样本上的对比结果:
data:image/s3,"s3://crabby-images/0d607/0d607ccaae9fbc9210df2ab557c9049d1009be2b" alt=""
网友评论