Image Deformation with Moving Least Squares (MLS)

Author: 雨幻逐光 | Published: 2020-03-27 17:20

The idea behind this algorithm builds on the method of least squares, so let us review that first.

The method of least squares

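The method of least squares is a way to estimate the parameters of a function. A simple example illustrates the idea:
Suppose we want to measure the thickness $y$ of a wall. After $n$ measurements we have $n$ values $(y_1, y_2, \ldots, y_n)$. Because each measurement may contain error, what criterion should we use to estimate the thickness? Most people will think of taking the mean of the $n$ values: $y = \frac{1}{n}\sum_{i=1}^{n} y_i$. The justification for that choice is precisely least squares. Let the true thickness be $y$; then the error of the $i$-th measurement is $\epsilon_i = y - y_i$. Least squares looks for the estimate $y$ that minimizes the sum of squared errors, that is, $\min_y \sum_{i=1}^{n}(y - y_i)^2$. The problem therefore becomes one of finding an extremum. Differentiating with respect to $y$ and setting the derivative to zero gives $2\sum_{i=1}^{n}(y - y_i) = 0$, from which $y = \frac{1}{n}\sum_{i=1}^{n} y_i$, the mean.
In summary, least squares estimates the parameters of a function from a set of measurements, choosing the parameters that give the smallest sum of squared errors over those measurements. The example above uses a constant function, so the estimated parameter is the constant itself (the thickness of the wall).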
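
As a quick numerical check of the wall example, the sketch below compares the mean with a brute-force search for the value that minimizes the sum of squared errors. The measurement values and the helper name `sum_squared_error` are made up for illustration; they are not from the original article.

```python
# Hypothetical wall-thickness measurements (illustrative values only).
measurements = [24.1, 23.8, 24.3, 24.0, 23.9]


def sum_squared_error(estimate, samples):
    """Sum of squared differences between one estimate and every sample."""
    return sum((estimate - s) ** 2 for s in samples)


# Closed-form least-squares solution for a constant model: the mean.
mean = sum(measurements) / len(measurements)

# Brute-force check: scan candidate estimates on a fine grid around the data.
candidates = [23.0 + 0.001 * k for k in range(2001)]  # 23.000 to 25.000
best = min(candidates, key=lambda c: sum_squared_error(c, measurements))

print(f"mean           = {mean:.3f}")
print(f"grid minimizer = {best:.3f}")
```

The two printed values should agree to within the grid spacing, which is what the derivation above predicts.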

Deforming an image with MLS

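This article mainly follows the paper "Image Deformation Using Moving Least Squares". The algorithm deforms an image by moving a set of control points placed on it.
Let p be a set of control points and q the positions of those points after they have been moved. What we need is a transformation f that gives, for every point v of the input image, its deformed position f(v). So how do we find f? The idea of MLS is the same as ordinary least squares: the solution is the one that minimizes an error. It is called "moving" because each v has its own set of weights, and therefore its own transformation f. For the full derivation, see the paper. Below I give an implementation of the corresponding terms from the paper (the paper provides affine, similarity and rigid deformations; this article implements only the affine one). None of the code is optimized, so that it is easy to match against the paper while studying.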
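
For reference, the quantities that the code in this article computes can be written out as below. This is a summary of the affine case as I read it from the paper, with points treated as row vectors and $\alpha = 1$ matching the inverse-squared-distance weights used in the code; consult the paper for the authoritative statement.

```latex
\begin{aligned}
&\min_{M,\,T}\ \sum_i w_i \,\bigl| p_i M + T - q_i \bigr|^2,
\qquad w_i = \frac{1}{|p_i - v|^{2\alpha}} \\
&p^* = \frac{\sum_i w_i\, p_i}{\sum_i w_i}, \qquad
q^* = \frac{\sum_i w_i\, q_i}{\sum_i w_i}, \qquad
\hat p_i = p_i - p^*, \quad \hat q_i = q_i - q^* \\
&M = \Bigl(\sum_i \hat p_i^{\,T}\, w_i\, \hat p_i\Bigr)^{-1} \sum_j w_j\, \hat p_j^{\,T}\, \hat q_j \\
&f(v) = (v - p^*)\,M + q^*
\end{aligned}
```

Because the weights $w_i$ depend on $v$, so do $p^*$, $q^*$ and $M$: a separate small least-squares problem is solved for every pixel.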

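Weights
    def get_weights(self, input_pixel):
        # One weight per control point: w_i = 1 / |p_i - v|^2
        Weights = np.zeros(self._num_cpoints)

        for i in range(self._num_cpoints):
            cpx, cpy = self._cp[i][0], self._cp[i][1]
            # input_pixel is [row, col]; control points are stored as (x, y)
            x, y = input_pixel[1], input_pixel[0]
            if x != cpx or y != cpy:
                weight = 1 / ((cpx - x) * (cpx - x) + (cpy - y) * (cpy - y))
            else:
                # The pixel coincides with a control point: use a very large weight
                weight = self._maximum

            Weights[i] = weight

        return Weights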
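
The same weights can be computed for all control points at once. The following is a minimal vectorized sketch, not the article's implementation: the function name `mls_weights`, the `alpha` parameter and the `eps` guard are my own assumptions (the article instead substitutes a large constant when the pixel coincides with a control point).

```python
import numpy as np


def mls_weights(control_points, v, alpha=1.0, eps=1e-8):
    """Return w_i = 1 / |p_i - v|^(2*alpha) for every control point p_i.

    control_points: array of shape (n, 2), rows are (x, y).
    v:              array-like of shape (2,), the query point as (x, y).
    eps:            hypothetical guard against division by zero when v
                    coincides with a control point (not from the article).
    """
    p = np.asarray(control_points, dtype=float)
    diff = p - np.asarray(v, dtype=float)
    sq_dist = np.sum(diff * diff, axis=1)          # |p_i - v|^2
    return 1.0 / np.maximum(sq_dist, eps) ** alpha


cp = np.array([[224, 89], [302, 153], [238, 152], [191, 206]], dtype=float)
w = mls_weights(cp, (224.0, 90.0))   # one pixel away from the first point
print(w)
```

The nearest control point dominates the result, which is what makes the deformation follow the control points locally.
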
p* and q*
    def getPStar(self, Weights):
        numerator = np.zeros(2)
        denominator = 0
        for i in range(self._num_cpoints):
            numerator[0] += Weights[i] * self._cp[i][0]
            numerator[1] += Weights[i] * self._cp[i][1]
            denominator += Weights[i]

        return numerator / denominator

    def getQStar(self, Weights):
        numerator = np.zeros(2)
        denominator = 0
        for i in range(self._num_cpoints):
            numerator[0] += Weights[i] * self._cq[i][0]
            numerator[1] += Weights[i] * self._cq[i][1]
            denominator += Weights[i]

        return numerator / denominator
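Both p* and q* are weighted centroids, so they can also be obtained with `np.average`. This is a sketch under assumptions: the helper name `weighted_centroid` and the weight values are made up for illustration.

```python
import numpy as np


def weighted_centroid(points, weights):
    """Weighted mean of 2-D points: sum(w_i * x_i) / sum(w_i)."""
    return np.average(np.asarray(points, dtype=float), axis=0, weights=weights)


cp = np.array([[224, 89], [302, 153], [238, 152], [191, 206]], dtype=float)
cq = np.array([[179, 75], [345, 178], [218, 175], [155, 254]], dtype=float)
w = np.array([4.0, 1.0, 2.0, 1.0])   # made-up weights, for illustration only

p_star = weighted_centroid(cp, w)
q_star = weighted_centroid(cq, w)
print(p_star, q_star)
```
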
Affine transformation matrix
    def getTransformMatrix(self, p_star, q_star, Weights):
        sum_pwp = np.zeros((2, 2))
        sum_wpq = np.zeros((2, 2))
        for i in range(self._num_cpoints):
            tmp_cp = (np.array(self._cp[i]) - np.array(p_star)).reshape(1, 2)
            tmp_cq = (np.array(self._cq[i]) - np.array(q_star)).reshape(1, 2)

            sum_pwp += np.matmul(tmp_cp.T*Weights[i], tmp_cp)
            sum_wpq += Weights[i] * np.matmul(tmp_cp.T, tmp_cq)

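        try:
            inv_sum_pwp = np.linalg.inv(sum_pwp)
        except np.linalg.LinAlgError:
            # Singular matrix: fall back to the identity transform
            if np.linalg.det(sum_pwp) < 1e-8:
                return np.identity(2)
            else:
                raise

        # Matrix product, not the element-wise `*` of NumPy arrays
        return np.matmul(inv_sum_pwp, sum_wpq)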
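
One way to check an implementation of the affine matrix is to feed it control points whose targets are an exact affine image of the sources; the recovered M should then equal the linear part of that map, whatever the weights are. The sketch below does this with a vectorized version. The function name `affine_matrix`, the matrix `A`, the translation `t` and the weights are all made up for illustration.

```python
import numpy as np


def affine_matrix(p, q, w):
    """M = (sum_i w_i ph_i^T ph_i)^-1 (sum_i w_i ph_i^T qh_i), points as rows."""
    p_hat = p - np.average(p, axis=0, weights=w)   # p_i - p*
    q_hat = q - np.average(q, axis=0, weights=w)   # q_i - q*
    pwp = (p_hat * w[:, None]).T @ p_hat           # 2x2
    pwq = (p_hat * w[:, None]).T @ q_hat           # 2x2
    # Solve pwp @ M = pwq rather than forming the inverse explicitly.
    return np.linalg.solve(pwp, pwq)


# Sanity check: if q is an exact affine image of p, M must recover A.
A = np.array([[1.2, 0.3], [-0.1, 0.9]])            # made-up linear part
t = np.array([5.0, -7.0])                          # made-up translation
p = np.array([[224, 89], [302, 153], [238, 152], [191, 206]], dtype=float)
q = p @ A + t
w = np.array([4.0, 1.0, 2.0, 1.0])                 # made-up weights

M = affine_matrix(p, q, w)
print(np.allclose(M, A))
```

Note that `@` (or `np.matmul`) is required here; with NumPy arrays, `*` multiplies element by element and would not give the matrix product.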

Full listing:

import numpy as np
from skimage import io
import math


class MLSImageWarping(object):
    def __init__(self, cp, cq, whether_color_q=False, point_size=3):
        self._cp = cp
        self._cq = cq
        self._whether_color_q = whether_color_q
        self._point_size = point_size
        self._num_cpoints = len(cp)
        self._maximum = 2**31-1

    def update_cp(self, cp):
        self._cp = cp

    def update_cq(self, cq):
        self._cq = cq

    def check_is_cq(self, x, y):
        for i in range(self._num_cpoints):
            if abs(x - self._cq[i][0]) <= self._point_size and abs(y - self._cq[i][1]) <= self._point_size:
                return True
        return False

    def get_weights(self, input_pixel):
        Weights = np.zeros(self._num_cpoints)

        for i in range(self._num_cpoints):
            cpx, cpy = self._cp[i][0], self._cp[i][1]
            x, y = input_pixel[1], input_pixel[0]
            if x != cpx or y != cpy:
                weight = 1 / ((cpx - x) * (cpx - x) + (cpy - y) * (cpy - y))
            else:
                weight = self._maximum

            Weights[i] = weight

        return Weights

    def getPStar(self, Weights):
        numerator = np.zeros(2)
        denominator = 0
        for i in range(self._num_cpoints):
            numerator[0] += Weights[i] * self._cp[i][0]
            numerator[1] += Weights[i] * self._cp[i][1]
            denominator += Weights[i]

        return numerator / denominator

    def getQStar(self, Weights):
        numerator = np.zeros(2)
        denominator = 0
        for i in range(self._num_cpoints):
            numerator[0] += Weights[i] * self._cq[i][0]
            numerator[1] += Weights[i] * self._cq[i][1]
            denominator += Weights[i]

        return numerator / denominator

    def getTransformMatrix(self, p_star, q_star, Weights):
        sum_pwp = np.zeros((2, 2))
        sum_wpq = np.zeros((2, 2))
        for i in range(self._num_cpoints):
            tmp_cp = (np.array(self._cp[i]) - np.array(p_star)).reshape(1, 2)
            tmp_cq = (np.array(self._cq[i]) - np.array(q_star)).reshape(1, 2)

            sum_pwp += np.matmul(tmp_cp.T*Weights[i], tmp_cp)
            sum_wpq += Weights[i] * np.matmul(tmp_cp.T, tmp_cq)

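        try:
            inv_sum_pwp = np.linalg.inv(sum_pwp)
        except np.linalg.LinAlgError:
            # Singular matrix: fall back to the identity transform
            if np.linalg.det(sum_pwp) < 1e-8:
                return np.identity(2)
            else:
                raise

        # Matrix product, not the element-wise `*` of NumPy arrays
        return np.matmul(inv_sum_pwp, sum_wpq)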


    def transfer(self, data):
        row, col, channel = data.shape
        res_data = np.zeros((row, col, channel), np.uint8)

        for j in range(col):
            for i in range(row):
                input_pixel = [i, j]
                Weights = self.get_weights(input_pixel)
                p_star = self.getPStar(Weights)
                q_star = self.getQStar(Weights)
                M = self.getTransformMatrix(p_star, q_star, Weights)

                # Inverse mapping: find the source position for this output pixel
                try:
                    inv_M = np.linalg.inv(M)
                except np.linalg.LinAlgError:
                    if np.linalg.det(M) < 1e-8:
                        inv_M = np.identity(2)
                    else:
                        raise

                pixel = np.matmul((np.array([input_pixel[1], input_pixel[0]]) - np.array(q_star)).reshape(1, 2),
                                  inv_M) + np.array(p_star).reshape(1, 2)

                pixel_x = pixel[0][0]
                pixel_y = pixel[0][1]

                if math.isnan(pixel_x):
                    pixel_x = 0
                if math.isnan(pixel_y):
                    pixel_y = 0

                # pixel_x, pixel_y = max(min(int(pixel_x), row-1), 0), max(min(int(pixel_y), col-1), 0)
                pixel_x, pixel_y = max(min(int(pixel_x), col - 1), 0), max(min(int(pixel_y), row - 1), 0)

                if self._whether_color_q:
                    if self.check_is_cq(j, i):
                        res_data[i][j] = np.array([255, 0, 0]).astype(np.uint8)
                    else:
                        res_data[i][j] = data[pixel_y][pixel_x]
                else:
                    res_data[i][j] = data[pixel_y][pixel_x]

        return res_data

if __name__ == "__main__":
    ## control points
    # np.float was removed from recent NumPy releases; the built-in float is equivalent
    cp = np.array([[224, 89], [302, 153], [238, 152], [191, 206]]).astype(float)
    cq = np.array([[179, 75], [345, 178], [218, 175], [155, 254]]).astype(float)

    img = io.imread("path to read")

    mls = MLSImageWarping(cp, cq, True)

    res_img = mls.transfer(img)

    io.imsave("path to write", res_img)

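From the paper, for a given v the deformed position is $f(v) = (v - p^*)M + q^*$. When generating the output image we want to iterate over every pixel of the result, so in practice we need the reverse: given a deformed point, find its position in the source image. Solving the equation above for v gives $v = (f(v) - q^*)M^{-1} + p^*$. Writing this mapping as $g$ and the deformed point as $y$, we have $g(y) = (y - q^*)M^{-1} + p^*$, and $g(y)$ is the point of the source image that corresponds to $y$. Note that in the code above the weights, and therefore $p^*$, $q^*$ and $M$, are evaluated at the output pixel $y$.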
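
The relationship between f and g can be checked with a round trip. The sketch below holds $M$, $p^*$ and $q^*$ fixed, which is an assumption made only for this illustration (in the full algorithm they vary from pixel to pixel); all numeric values and the function names `forward` and `inverse` are made up.

```python
import numpy as np


def forward(v, M, p_star, q_star):
    """f(v) = (v - p*) M + q*, with points as row vectors."""
    return (v - p_star) @ M + q_star


def inverse(y, M, p_star, q_star):
    """g(y) = (y - q*) M^-1 + p*, the inverse of f for the same M, p*, q*."""
    return (y - q_star) @ np.linalg.inv(M) + p_star


# Made-up values, for illustration only.
M = np.array([[1.2, 0.3], [-0.1, 0.9]])
p_star = np.array([233.0, 127.0])
q_star = np.array([206.0, 135.0])

v = np.array([250.0, 140.0])
y = forward(v, M, p_star, q_star)
v_back = inverse(y, M, p_star, q_star)
print(y, v_back)
```

Mapping v forward and then back returns the original point, which confirms that the parentheses around $(y - q^*)$ matter: the subtraction must happen before multiplying by $M^{-1}$.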
