6 Module - Dissecting PyTorch

Author: readilen | Published: 2018-10-24 00:02

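    The Module class holds the functions common to every module class.

    Modules in PyTorch are very easy to use: you only need to derive from Module and override two functions (typically __init__ and forward). So what does Module itself do?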

    class Module(object):
        def __init__(self):
            self._backend = thnn_backend
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
            self._modules = OrderedDict()
            self.training = True
    

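    The constructor creates a set of ordered dictionaries that hold the module's various state; we set those aside for now and start with the first member. self._backend refers to a global THNNFunctionBackend instance, which stores a collection of function references. THNNFunctionBackend derives from FunctionBackend: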

    class FunctionBackend(object):
        def __init__(self):
            self.function_classes = {}
        def register_function(self, name, function_class):
            self.function_classes[name] = function_class
    

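    In this class, function_classes is a dictionary whose keys are names and whose values are functions or function classes; register_function adds entries to it. Once registration is complete it holds about 118 entries. The PyTorch version used in this article is 0.4.1.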
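A minimal sketch of how such a name-to-function registry behaves, using toy stand-ins rather than the real THNN classes. The `__getattr__` lookup, the `dump` helper, and the two registered entries are assumptions made for illustration; only `function_classes` and `register_function` come from the class shown above. `dump` produces one line per entry in the same two-column layout as the article's listing.

```python
class FunctionBackend(object):
    """Toy registry mirroring the class above, plus attribute-style lookup."""

    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so
        # self.function_classes itself is found the ordinary way.
        try:
            return self.function_classes[name]
        except KeyError:
            raise AttributeError(name)


# Hypothetical stand-ins for real registered entries.
class Sigmoid(object):
    pass


def RNNTanhCell(*args):
    return "RNNTanhCell called"


backend = FunctionBackend()
backend.register_function("Sigmoid", Sigmoid)
backend.register_function("RNNTanhCell", RNNTanhCell)


def dump(backend):
    """Return one 'name value' line per entry, names left-aligned."""
    width = max(len(name) for name in backend.function_classes)
    return ["{:<{w}} {!r}".format(name, fn, w=width)
            for name, fn in backend.function_classes.items()]


for line in dump(backend):
    print(line)
```

With a registry like this, a caller can write `backend.Sigmoid` instead of indexing the dictionary directly, which is one plausible reason to keep all function classes behind a single backend object.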

    RNN                                      <function RNN at 0x7f4330534378>
    RNNTanhCell                              <function RNNTanhCell at 0x7f4330530d90>
    RNNReLUCell                              <function RNNReLUCell at 0x7f43305309d8>
    LSTMCell                                 <function LSTMCell at 0x7f4330530e18>
    GRUCell                                  <function GRUCell at 0x7f4330530ea0>
    Dropout                                  <class 'torch.nn._functions.dropout.Dropout'>
    Dropout2d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
    Dropout3d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
    MarginCriterion                          <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
    MarginCriterionBackward                  <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
    GatedLinear                              <class 'torch.nn._functions.thnn.auto.GatedLinear'>
    GatedLinearBackward                      <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
    SpatialFullConvolutionMap                <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
    SpatialFullConvolutionMapBackward        <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
    VolumetricFractionalMaxPooling           <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
    VolumetricFractionalMaxPoolingBackward   <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
    VolumetricFullDilatedConvolution         <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
    VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
    Col2Im                                   <class 'torch.nn._functions.thnn.auto.Col2Im'>
    Col2ImBackward                           <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
    DilatedConv2d                            <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
    DilatedConv2dBackward                    <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
    SpatialConvolutionLocal                  <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
    SpatialConvolutionLocalBackward          <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
    FeatureLPPooling                         <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
    FeatureLPPoolingBackward                 <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
    VolumetricGridSamplerBilinear            <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
    VolumetricGridSamplerBilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
    TemporalUpSamplingNearest                <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
    TemporalUpSamplingNearestBackward        <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
    SpatialUpSamplingNearest                 <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
    SpatialUpSamplingNearestBackward         <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
    ReflectionPad1d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
    ReflectionPad1dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
    SpatialConvolutionMap                    <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
    SpatialConvolutionMapBackward            <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
    NLLLoss                                  <class 'torch.nn._functions.thnn.auto.NLLLoss'>
    NLLLossBackward                          <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
    Softplus                                 <class 'torch.nn._functions.thnn.auto.Softplus'>
    SoftplusBackward                         <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
    LogSigmoid                               <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
    LogSigmoidBackward                       <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
    SpatialUpSamplingBilinear                <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
    SpatialUpSamplingBilinearBackward        <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
    ReplicationPad3d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
    ReplicationPad3dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
    MultiMarginLoss                          <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
    MultiMarginLossBackward                  <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
    ReplicationPad1d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
    ReplicationPad1dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
    MultiLabelMarginLoss                     <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
    MultiLabelMarginLossBackward             <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
    SpatialFullDilatedConvolution            <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
    SpatialFullDilatedConvolutionBackward    <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
    SoftMarginLoss                           <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
    SoftMarginLossBackward                   <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
    NLLLoss2d                                <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
    NLLLoss2dBackward                        <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
    MSELoss                                  <class 'torch.nn._functions.thnn.auto.MSELoss'>
    MSELossBackward                          <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
    Sigmoid                                  <class 'torch.nn._functions.thnn.auto.Sigmoid'>
    SigmoidBackward                          <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
    VolumetricUpSamplingTrilinear            <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
    VolumetricUpSamplingTrilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
    BCELoss                                  <class 'torch.nn._functions.thnn.auto.BCELoss'>
    BCELossBackward                          <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
    Square                                   <class 'torch.nn._functions.thnn.auto.Square'>
    SquareBackward                           <class 'torch.nn._functions.thnn.auto.SquareBackward'>
    ReplicationPad2d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
    ReplicationPad2dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
    L1Loss                                   <class 'torch.nn._functions.thnn.auto.L1Loss'>
    L1LossBackward                           <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
    SpatialGridSamplerBilinear               <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
    SpatialGridSamplerBilinearBackward       <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
    Sqrt                                     <class 'torch.nn._functions.thnn.auto.Sqrt'>
    SqrtBackward                             <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
    TemporalRowConvolution                   <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
    TemporalRowConvolutionBackward           <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
    SpatialFractionalMaxPooling              <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
    SpatialFractionalMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
    TemporalUpSamplingLinear                 <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
    TemporalUpSamplingLinearBackward         <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
    VolumetricDilatedMaxPooling              <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
    VolumetricDilatedMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
    Threshold                                <class 'torch.nn._functions.thnn.auto.Threshold'>
    ThresholdBackward                        <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
    Abs                                      <class 'torch.nn._functions.thnn.auto.Abs'>
    AbsBackward                              <class 'torch.nn._functions.thnn.auto.AbsBackward'>
    Softshrink                               <class 'torch.nn._functions.thnn.auto.Softshrink'>
    SoftshrinkBackward                       <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
    LeakyReLU                                <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
    LeakyReLUBackward                        <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
    VolumetricUpSamplingNearest              <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
    VolumetricUpSamplingNearestBackward      <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
    VolumetricDilatedConvolution             <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
    VolumetricDilatedConvolutionBackward     <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
    Tanh                                     <class 'torch.nn._functions.thnn.auto.Tanh'>
    TanhBackward                             <class 'torch.nn._functions.thnn.auto.TanhBackward'>
    TemporalSubSampling                      <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
    TemporalSubSamplingBackward              <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
    ELU                                      <class 'torch.nn._functions.thnn.auto.ELU'>
    ELUBackward                              <class 'torch.nn._functions.thnn.auto.ELUBackward'>
    Hardtanh                                 <class 'torch.nn._functions.thnn.auto.Hardtanh'>
    HardtanhBackward                         <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
    L1Cost                                   <class 'torch.nn._functions.thnn.auto.L1Cost'>
    L1CostBackward                           <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
    SpatialSubSampling                       <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
    SpatialSubSamplingBackward               <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
    Im2Col                                   <class 'torch.nn._functions.thnn.auto.Im2Col'>
    Im2ColBackward                           <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
    KLDivLoss                                <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
    KLDivLossBackward                        <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
    SmoothL1Loss                             <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
    SmoothL1LossBackward                     <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
    ReflectionPad2d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
    ReflectionPad2dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
    CrossMapLRN2d                            <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
    EmbeddingBag                             <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>
    

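    Without really meaning to, we have just listed every predefined function class that PyTorch registers with this backend. Their implementations will be covered later.

    The other ordered dictionaries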

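            self._parameters = OrderedDict()         # the module's learnable parameters
            self._buffers = OrderedDict()            # persistent non-parameter state (e.g. running statistics)
            self._backward_hooks = OrderedDict()     # backward hooks
            self._forward_hooks = OrderedDict()      # forward hooks
            self._forward_pre_hooks = OrderedDict()  # hooks that run before forward is called
            self._modules = OrderedDict()            # child modules
            self.training = True                     # training mode (True) or evaluation mode (False)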
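
One way to picture how these dictionaries get filled is the toy model below. It is a sketch under stated assumptions, not PyTorch's code: `ToyParameter`, `ToyModule`, `Layer`, and `Net` are hypothetical names, and only the general idea (attribute assignment routes parameters and child modules into their dictionaries, and the state dictionary is collected recursively under dotted names) is meant to match the real Module.

```python
from collections import OrderedDict


class ToyParameter(object):
    """Hypothetical stand-in for a parameter: just wraps a value."""

    def __init__(self, data):
        self.data = data


class ToyModule(object):
    """Toy model of Module's bookkeeping; not PyTorch's actual code."""

    def __init__(self):
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def register_buffer(self, name, value):
        self._buffers[name] = value

    def __setattr__(self, name, value):
        # Route by type: parameters and child modules go to their dicts,
        # everything else becomes an ordinary instance attribute.
        if isinstance(value, ToyParameter):
            self._parameters[name] = value
        elif isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Reached only when normal lookup fails: search the dicts.
        for store in ("_parameters", "_buffers", "_modules"):
            entries = self.__dict__.get(store, {})
            if name in entries:
                return entries[name]
        raise AttributeError(name)

    def state_dict(self, prefix=""):
        """Collect parameters and buffers recursively under dotted names."""
        out = OrderedDict()
        for name, param in self._parameters.items():
            out[prefix + name] = param.data
        for name, buf in self._buffers.items():
            out[prefix + name] = buf
        for name, child in self._modules.items():
            out.update(child.state_dict(prefix + name + "."))
        return out

    def train(self, mode=True):
        # Set the flag on this module and on every child module.
        self.training = mode
        for child in self._modules.values():
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Layer(ToyModule):
    def __init__(self):
        super(Layer, self).__init__()
        self.weight = ToyParameter([1.0, 2.0])
        self.register_buffer("running_mean", [0.0])


class Net(ToyModule):
    def __init__(self):
        super(Net, self).__init__()
        self.layer = Layer()
        self.bias = ToyParameter([0.5])


net = Net()
print(list(net.state_dict().keys()))
net.eval()
print(net.layer.training)
```

Note that `weight` never lands in the instance's `__dict__`; it lives only in `_parameters`, which is why the sketch needs `__getattr__` to make `net.layer.weight` work.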
    

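    Module functions

    The purpose of most Module functions can be guessed from their names; they are only listed here and not described in detail.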

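    Name                       Purpose
    forward                    Forward computation; subclasses must override it
    register_buffer            Register a persistent buffer
    register_parameter         Register a parameter
    add_module                 Add a child module
    _apply                     Apply a function to all parameters (and buffers)
    apply                      Apply a function to every submodule
    cuda                       Move to the GPU
    cpu                        Move to the CPU
    type                       Cast all parameters to the given type
    float                      Cast everything to single-precision floating point
    double                     Cast everything to double-precision floating point
    half                       Cast everything to half precision (16-bit, two bytes)
    to                         User-facing interface for changing type and CPU/GPU device; internally still calls _
    register_backward_hook     Register a backward hook
    register_forward_pre_hook  Register a hook that runs before forward is called
    register_forward_hook      Register a forward hook
    _slow_forward              Forward path without acceleration
    __call__                   Makes the module callable: runs the forward pass on the given arguments
    __setstate__               Restore the state of all the dictionaries at once
    __getattr__                Get an attribute
    __setattr__                Set an attribute
    __delattr__                Delete an attribute
    state_dict                 Return the current state as a dictionary
    _load_from_state_dict      Worker function that loads from a state dictionary
    load_state_dict            User-facing interface for loading state
    children                   Child modules
    modules                    All modules
    train                      Switch to training mode
    eval                       Switch to evaluation mode
    zero_grad                  Zero the gradients of all parameters
    share_memory               Use shared memory
    __repr__                   String representation of the module
    __dir__                    List attribute names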
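
To make the relationship between the hook-registration functions, forward, and the call operator in the table above more concrete, here is a simplified sketch. It is a toy model under stated assumptions: `ToyModule` and `Double` are hypothetical, hooks are keyed by a simple counter, and details of the real implementation (removable hook handles, backward hooks, the non-accelerated path) are left out.

```python
from collections import OrderedDict


class ToyModule(object):
    """Toy model of the call path; hook bookkeeping is simplified."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()

    def register_forward_pre_hook(self, hook):
        # hook(module, inputs) runs before forward.
        self._forward_pre_hooks[len(self._forward_pre_hooks)] = hook

    def register_forward_hook(self, hook):
        # hook(module, inputs, output) runs after forward.
        self._forward_hooks[len(self._forward_hooks)] = hook

    def forward(self, *inputs):
        # Subclasses are expected to override this.
        raise NotImplementedError

    def __call__(self, *inputs):
        for hook in self._forward_pre_hooks.values():
            hook(self, inputs)
        result = self.forward(*inputs)
        for hook in self._forward_hooks.values():
            hook(self, inputs, result)
        return result


class Double(ToyModule):
    def forward(self, x):
        return 2 * x


events = []
m = Double()
m.register_forward_pre_hook(lambda mod, inp: events.append(("pre", inp)))
m.register_forward_hook(lambda mod, inp, out: events.append(("post", inp, out)))
y = m(21)
print(y, events)
```

The point of routing everything through the call operator rather than calling forward directly is that hooks registered on the module always get a chance to run.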


        Article link: https://www.haomeiwen.com/subject/ftvzzftx.html