美文网首页
torch.meshgrid()和np.meshgrid()的区

torch.meshgrid()和np.meshgrid()的区

作者: 运动小爽 | 来源:发表于2020-02-08 18:43 被阅读0次

    np.meshgrid()函数常用于生成二维网格,比如图像的坐标点。
    pytorch中也有一个类似的函数torch.meshgrid(),功能也类似,但是两者的用法有区别,使用时需要注意(刚踩坑,因此记录一下。。。)

    比如我要生成一张图像(h=6, w=10)的xy坐标点,看下两者的实现方式:

    np.meshgrid()

    >>> import numpy as np
    >>> h = 6
    >>> w = 10
    >>> xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    >>> xs.shape
    (6, 10)
    >>> ys.shape
    (6, 10)
    >>> xs
    array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
           [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    >>> ys
    array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
           [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
           [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
           [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
           [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
           [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])
    >>> xys = np.stack([xs, ys], axis=-1)
    >>> xys.shape
    (6, 10, 2)
    

    torch.meshgrid()

    >>> import torch
    >>> h = 6
    >>> w = 10
    >>> ys,xs = torch.meshgrid(torch.arange(h), torch.arange(w))
    >>> xs.shape
    torch.Size([6, 10])
    >>> ys.shape
    torch.Size([6, 10])
    >>> xs
    tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    >>> ys
    tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
            [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
            [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]])
    >>> xys = torch.stack([xs, ys], dim=-1)
    >>> xys.shape
    torch.Size([6, 10, 2])
    

    从python交互式窗口可以清晰的看出numpy和pytorch中meshgrid()函数的区别,就不用文字总结了,自己体会哈哈哈。

    相关文章

      网友评论

          本文标题:torch.meshgrid()和np.meshgrid()的区

          本文链接:https://www.haomeiwen.com/subject/otoixhtx.html