pytorch 函数理解

作者: LCG22 | 来源:发表于2021-04-27 16:02 被阅读0次

    1、torch.nn.Unfold

    函数作用:

    unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘

    参数:

    kernel_size: _size_any_t, 卷积核的大小

    dilation: _size_any_t=1, 卷积核元素之间的空洞个数

    padding: _size_any_t=0, 填充特征四周的列数,默认为 0,则不填充

    stride: _size_any_t=1,卷积核移动的步长

    函数理解:

    参考资料:

    PYTORCH实现手动滑窗,卷积(利用UNFOLD,FOLD操作)

    unfold 过程:

    ① 对于 batch 里的每个数据分别进行 unfold

    ② 分别在每个数据的每个通道上,使用大小为 k*k 的卷积核进行从左往右,从上向下的滑窗

    ③ 对于在每个通道上分别得到的第一个滑窗区域,分别进行 reshape 成行向量,然后把在所有通道上得到的行向量,进行横向拼接,得到新的行向量

    ④ 对于在每个通道上得到的滑窗区域都进行步骤 ③ 的操作,直到所有的滑窗区域都处理完

    ⑤ 将步骤 ③ 和 步骤 ④ 中得到的行向量,进行纵向拼接,得到一个矩阵

    ⑥ 完成 unfold 操作,将 batch 中每个数据进行 unfold 得到的矩阵进行堆放,得到输出结果

    例子:

    x = torch.range(1, 2*3*4*5)

    print(x.shape)

    batch_x = x.reshape([2, 3, 4, 5])

    print(batch_x.shape)

    # unfold 是展开的意思,在 torch 中则是只卷不积,相当于只滑窗,不进行元素相乘

    unfold = torch.nn.Unfold(3)

    res = unfold(batch_x)

    print(res.shape)

    结果:

    torch.Size([2, 27, 6])

    分析:

    假设输入的 batch_x 维度为 [2, 3, 4, 5],其中 2 是批的数据量大小 B, 3 是通道数 C,4 是高度 H,5 是宽度 W 。使用的卷积核大小 K 为 3*3,移动步长 S 为 1,padding 为 0

    ① 在 B 的每个数据上进行 unfold

    ② 同时在每个通道上的最左上角开始进行滑动,对于每个通道,得到大小为 9 的滑动区域,然后进行 Reshape 成维度为 [1, 9] 的行向量。然后将在所有 3 个通道上得到的 3 个行向量,进行横向拼接,得到维度为 [1, 27] 的行向量。

    ③ 依次将卷积核按照从左到右,从上往下的顺序,按照步长 1 进行滑动,每个滑动的区域经过步骤 ② 中处理后都能得到一个维度为 [1, 27] 的行向量,共得到 6 个维度为 [1, 27],然后纵向堆叠成维度为 [6, 27] 的矩阵

    ④ 将每个数据经过 unfold 得到的维度为 [6, 27]  的矩阵进行堆叠成维度为 [2, 27, 6]  的张量

    相关文章

      网友评论

        本文标题:pytorch 函数理解

        本文链接:https://www.haomeiwen.com/subject/fvjerltx.html