美文网首页
统一 MXNet & PyTorch & TensorFlow

统一 MXNet & PyTorch & TensorFlow

作者: 水之心 | 来源:发表于2021-02-21 23:28 被阅读0次

    为了在 MXNet、PyTorch 以及 TensorFlow 之间进行互转,本文研究它们之间的基础运算的异同。

    基础函数

    对于 MXNet 有 np(numpy)模块和 npx(numpy_extension)模块。np 模块包含了 NumPy 支持的函数。而 npx 模块包含了一组扩展函数,用来在类似 NumPy 的环境中实现深度学习开发。当使用张量时,我们几乎总是会调用 set_np 函数:这是为了兼容 MXNet 其他的张量处理组件。

    from mxnet import np, npx
    
    npx.set_np()
    

    同样,我们可以将 PyTorch 和 TensorFlow 写作 np,这样有:

    def import_np(module_name):
        if module_name == 'mxnet':
            from mxnet import np, npx
            npx.set_np()
            np.randn = np.random.randn
            return np
        elif module_name == 'torch':
            import torch as np
            np.array = np.tensor
            np.concatenate = np.cat
            return np
        elif module_name == 'tensorflow':
            from tensorflow.experimental import numpy as np
            return np
    

    这样 MXNet,TensorFlow 与 PyTorch 有相同的函数:

    MXNet TensorFlow PyTorch

    当然也有许多不同,比如求张量的大小,标准正太分布与张量定义,MXNet:

    MXNet TensorFlow PyTorch

    张量运算

    MXNet,TensorFLow 与 PyTorch 是几乎一致的。

    逐元素运算

    MXNet TensorFlow PyTorch

    张量运算

    张量的真值、元素求和、拼接:

    MXNet TensorFlow PyTorch

    仔细观察也可以看到细微的不同,np.arange 的数据类型的默认值不同,故而,建议:定义张量最好也把 dtype 也指定了。

    广播机制

    MXNet TensorFlow PyTorch

    索引和切片

    MXNet TensorFlow PyTorch

    节省内存

    运行一些操作可能会导致为新结果分配内存。例如,如果我们用 Y = X + Y,我们将取消引用 Y 指向的张量,而是指向新分配的内存处的张量。下面展示了节省内存的方法:

    MXNet & PyTorch TensorFlow

    __array__ 的妙用

    由于 MXNet,TensorFLow 与 PyTorch 均实现了 __array__,故而,它们均可直接传入 NumPy 数据到张量。下面仅以 TensorFlow 为例:

    这样的好处是,可以直接使用 Matplotlib 画图:

    BUG

    nvidia-smi指令报错:Failed to initialize NVML: Driver解决 - 知乎 (zhihu.com)

    相关文章

      网友评论

          本文标题:统一 MXNet & PyTorch & TensorFlow

          本文链接:https://www.haomeiwen.com/subject/yjfcfltx.html