[mmaction2 version] Video Classification (Part 1) TSM: Tempor

Author: blackmanba_084b | Published 2022-05-27 17:23

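    I recently spent some time studying video classification code, and I plan to write a series of posts introducing the related papers and their code. This post covers the paper Temporal Shift Module for Efficient Video Understanding.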
    paper: TSM: Temporal Shift Module for Efficient Video Understanding
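    code: ① temporal-shift-module (the official code, recommended) ② [mmaction2](https://github.com/open-mmlab/mmaction2) (the code walked through in this post; it relies on many registry- and hook-based Python modules, which can be hard for beginners to follow)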

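    Before TSM, mainstream video classification mostly relied on 3D CNNs, but as the paper points out, they are computationally expensive and poorly suited to online video recognition. To address this, the authors proposed TSM. It comes in two variants: a bi-directional shift for offline video recognition, and a uni-directional shift (features only move from past frames to the current frame) for online, real-time video recognition.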
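The online variant is easiest to picture as a loop over incoming frames that carries a small cache between steps. The sketch below is a NumPy illustration under my own assumptions (the hypothetical name `online_shift_stream`, a per-frame `(C, H, W)` layout, and replacing only the first 1/8 of the channels); it is not the official implementation, where this would happen inside each shifted block of the network.

```python
import numpy as np

def online_shift_stream(frames, shift_div=8):
    """Uni-directional temporal shift applied frame by frame (hypothetical helper).

    frames: iterable of per-frame feature maps shaped (C, H, W), arriving in order.
    The first C // shift_div channels of each output are replaced by the same
    channels cached from the previous frame (zeros for the very first frame).
    """
    cache = None
    for feat in frames:
        fold = feat.shape[0] // shift_div
        out = feat.copy()
        # Past -> current only: no future frame is needed, so this can run online.
        out[:fold] = 0 if cache is None else cache
        # Cache this frame's original channels for the next time step.
        cache = feat[:fold].copy()
        yield out

# Three frames with 8 channels each; frame t is filled with the value t + 1.
stream = [np.full((8, 1, 1), float(t + 1)) for t in range(3)]
outs = list(online_shift_stream(stream, shift_div=8))
```

Only `C // shift_div` channels per frame need to be kept in memory, which is what makes the online variant cheap for streaming video.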

    I. How the Model Works


    1.1 The Shift Mechanism

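    The core of the paper, shown in the figure below, is to shift part of the channels along the temporal dimension so that each frame's features also carry information from neighboring frames.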


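    First, how to read the figure: the vertical axis is the temporal dimension T (frames are sampled from the video, usually 8 or 16); the horizontal axis is the channel dimension C; and the depth axis is the spatial dimension of the features (width × height).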

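    Shifting some channels up and down along the temporal dimension enlarges the temporal receptive field. The blank positions left by the shift are zero-padded, and values shifted past the first or last frame are dropped.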
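The bi-directional (offline) shift can be sketched in a few lines of NumPy. This is a minimal sketch assuming features are laid out as `(N * T, C, H, W)` with the frames of each video stored contiguously; the function name `temporal_shift` is my own, and the channel split only loosely follows the idea behind mmaction2's `TemporalShift`, not its code.

```python
import numpy as np

def temporal_shift(x, num_segments, shift_div=8):
    """Bi-directional (offline) temporal shift.

    x: features shaped (N * T, C, H, W), frames of the same video stored contiguously.
    """
    nt, c, h, w = x.shape
    x = x.reshape(nt // num_segments, num_segments, c, h, w)
    fold = c // shift_div
    out = np.zeros_like(x)  # positions with nothing shifted in stay zero (zero padding)
    # Channels [0, fold): frame t receives frame t+1's features.
    out[:, :-1, :fold] = x[:, 1:, :fold]
    # Channels [fold, 2*fold): frame t receives frame t-1's features.
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]
    # Remaining channels are not shifted.
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]
    return out.reshape(nt, c, h, w)

# One video, 4 frames, 8 channels; frame t is filled with the value t + 1.
x = np.repeat(np.arange(1.0, 5.0), 8).reshape(4, 8, 1, 1)
y = temporal_shift(x, num_segments=4, shift_div=8)
print(y[:, 0, 0, 0])  # channel 0: [2. 3. 4. 0.]
print(y[:, 1, 0, 0])  # channel 1: [0. 1. 2. 3.]
```

The operation only copies memory and performs no multiplications, which is why the shift itself adds zero FLOPs.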

    So how many channels should be shifted? Here is what the authors say:


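    That is, 1/8 of the channels take their features from the previous frame, another 1/8 take their features from the next frame, and the remaining channels are left unchanged. The authors also explain in the paper why 1/8 was chosen: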
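As a quick piece of arithmetic, here is how a channel dimension splits with `shift_div=8` (the value used in the model config of this post). The helper name `shift_split` is hypothetical; the widths below are the `conv1` input widths that appear in the printed ResNet-50 layers.

```python
def shift_split(channels, shift_div=8):
    """Return (from_previous, from_next, unchanged) channel counts for one shift."""
    fold = channels // shift_div
    return fold, fold, channels - 2 * fold

for c in (64, 256, 512, 1024):
    print(c, shift_split(c))
```

So 1/8 + 1/8 = 1/4 of the channels move in total and 3/4 stay in place.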

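    The main reason is to balance spatial and temporal information. Although the shift adds zero FLOPs, it still incurs memory (data movement) overhead, and shifting too many channels also destroys spatial information and hurts accuracy. The experiment below shows the effect of shifting all channels, which severely hurts inference efficiency: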

    1.2 Residual Shift


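    The experiments show that residual TSM, which places the shift on the residual branch, fuses temporal information and outperforms in-place TSM: in-place TSM shifts the block input itself, which degrades the network's ability to learn spatial features.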
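A toy comparison makes the difference concrete. In this sketch the features are reduced to `(T, C)` and the whole residual branch is replaced by a single `(C, C)` matrix `w`, both simplifying assumptions of mine. In mmaction2 terms this roughly corresponds to `shift_place='block'` (the whole block is wrapped) versus `'blockres'` (only `conv1.conv` is wrapped).

```python
import numpy as np

def shift_time(x, fold):
    """Bi-directional temporal shift on (T, C) features with zero padding."""
    out = x.copy()
    out[:, :2 * fold] = 0
    out[:-1, :fold] = x[1:, :fold]                  # from the next frame
    out[1:, fold:2 * fold] = x[:-1, fold:2 * fold]  # from the previous frame
    return out

def in_place_block(x, w, fold):
    y = shift_time(x, fold)  # the block input itself is shifted ...
    return y + y @ w         # ... so the identity path also carries shifted features

def residual_block(x, w, fold):
    return x + shift_time(x, fold) @ w  # only the residual branch sees shifted features

x = np.arange(32, dtype=float).reshape(4, 8)  # T=4 frames, C=8 channels
w_zero = np.zeros((8, 8))                      # a residual branch that outputs zeros
print(np.array_equal(residual_block(x, w_zero, fold=1), x))  # True
print(np.array_equal(in_place_block(x, w_zero, fold=1), x))  # False
```

With a residual branch that contributes nothing, the residual version still returns the original activations, while the in-place version has already overwritten part of them.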

    1.3 Overall Architecture


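    With the figure above, it is easy to see how the model classifies a video. Each sampled frame goes through the shift operation described above before its convolution block (the code later makes this concrete). At the output, the features are globally average-pooled, passed through an fc (fully connected) layer, and averaged over the frames to give the class scores of the video.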
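The output stage can be sketched in NumPy as follows. This is a simplified stand-in for a TSM-style head, not mmaction2's `TSMHead`: it omits dropout, uses a small channel count instead of 2048, and the names `tsm_head` and `softmax` are mine. Because the fc layer is linear, applying it per frame and then averaging gives the same result as averaging first.

```python
import numpy as np

def tsm_head(feats, num_segments, weight, bias):
    """Toy classification head.

    feats:  backbone output shaped (N * T, C, H, W)
    weight: fc weight shaped (num_classes, C); bias: (num_classes,)
    Returns video-level scores shaped (N, num_classes).
    """
    pooled = feats.mean(axis=(2, 3))                     # global average pooling -> (N*T, C)
    logits = pooled @ weight.T + bias                    # per-frame fc -> (N*T, K)
    logits = logits.reshape(-1, num_segments, logits.shape[-1])
    return logits.mean(axis=1)                           # average over the T frames -> (N, K)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
feats = rng.standard_normal((2 * 8, 32, 7, 7))  # N=2 videos, T=8 frames, C=32 (small for the demo)
weight, bias = rng.standard_normal((7, 32)), rng.standard_normal(7)
scores = tsm_head(feats, 8, weight, bias)
probs = softmax(scores)                          # class probabilities per video
print(scores.shape)  # (2, 7)
```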

    II. Code Walkthrough

    Here the mmaction2 framework is used for action classification. The TSM training code is covered in three parts: ① data processing ② model structure ③ loss computation

    2.1 Data Processing

    First, let's look at the data configuration file:

    _base_ = [
        '../../_base_/models/tsm_r50.py', '../../_base_/schedules/sgd_tsm_50e.py',
        '../../_base_/default_runtime.py'
    ]
    
    dataset_type = 'VideoDataset'
    data_root = ""
    data_root_val = ""
    ann_file_train = '/data/humaocheng/action_classification/ava_datasets/train_data.txt'
    ann_file_val = '/data/humaocheng/action_classification/ava_datasets/test_data.txt'
    ann_file_test = '/data/humaocheng/action_classification/ava_datasets/test_data.txt'
    
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
    
    train_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, 256)),
        dict(
            type='MultiScaleCrop',
            input_size=224,
            scales=(1, 0.875, 0.75, 0.66),
            random_crop=False,
            max_wh_scale_gap=1),
        dict(type='Resize', scale=(224, 224), keep_ratio=False),
        dict(type='Flip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs', 'label'])
    ]
    
    
    val_pipeline = [
        dict(type='DecordInit'),
        dict(
            type='SampleFrames',
            clip_len=1,
            frame_interval=1,
            num_clips=8,
            test_mode=True),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, 256)),
        dict(type='CenterCrop', crop_size=224),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs'])
    ]
    test_pipeline = [
        dict(type='DecordInit'),
        dict(
            type='SampleFrames',
            clip_len=1,
            frame_interval=1,
            num_clips=8,
            test_mode=True),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, 256)),
        dict(type='CenterCrop', crop_size=224),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs'])
    ]
    data = dict(
        videos_per_gpu=8,
        workers_per_gpu=2,
        test_dataloader=dict(videos_per_gpu=1),
        train=dict(
            type=dataset_type,
            ann_file=ann_file_train,
            data_prefix=data_root,
            pipeline=train_pipeline),
        val=dict(
            type=dataset_type,
            ann_file=ann_file_val,
            data_prefix=data_root_val,
            pipeline=val_pipeline),
        test=dict(
            type=dataset_type,
            ann_file=ann_file_test,
            data_prefix=data_root_val,
            pipeline=test_pipeline))
    evaluation = dict(
        interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
    
    # runtime settings
    checkpoint_config = dict(interval=5)
    work_dir = './work_dirs/tsm_r50_1x1x8_100e_kinetics400_rgb/'
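The `ann_file_*` entries above point to plain-text lists. As far as I know, mmaction2's `VideoDataset` expects one video per line: a file path followed by an integer label, separated by whitespace. The helper below is a hypothetical re-implementation of that parsing for illustration (it assumes paths without spaces), not mmaction2's code. With `data_root = ""` the paths are used as written.

```python
import os

def parse_video_list(lines, data_prefix=""):
    """Parse 'path/to/video.mp4 label' lines into a list of dicts (hypothetical helper)."""
    infos = []
    for line in lines:
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        filename, label = parts[0], int(parts[1])
        if data_prefix:
            filename = os.path.join(data_prefix, filename)
        infos.append(dict(filename=filename, label=label))
    return infos

example = ["videos/clip_0001.mp4 3\n", "videos/clip_0002.mp4 0\n", "\n"]
print(parse_video_list(example))
```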
    

    Note: when training on your own dataset, set num_classes below to the number of classes in your dataset.

    # tsm_r50.py
    
    # model settings
    model = dict(
        type='Recognizer2D',
        backbone=dict(
            type='ResNetTSM',
            pretrained='torchvision://resnet50',
            depth=50,
            norm_eval=False,
            shift_div=8),
        cls_head=dict(
            type='TSMHead',
            num_classes=7,
            in_channels=2048,
            spatial_type='avg',
            consensus=dict(type='AvgConsensus', dim=1),
            dropout_ratio=0.5,
            init_std=0.001,
            is_shift=True),
        # model training and testing settings
        train_cfg=None,
        test_cfg=dict(average_clips='prob'))
    
    # optimizer
    optimizer = dict(
        type='SGD',
        lr=0.01,  # this lr is used for 8 gpus
        momentum=0.9,
        weight_decay=0.0001)
    optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
    # learning policy
    lr_config = dict(policy='step', step=[20, 40])
    total_epochs = 50
    

    First, let's see how train_pipeline processes the video dataset:

    train_pipeline = [
        dict(type='DecordInit'),
        dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8),
        dict(type='DecordDecode'),
        dict(type='Resize', scale=(-1, 256)),
        dict(
            type='MultiScaleCrop',
            input_size=224,
            scales=(1, 0.875, 0.75, 0.66),
            random_crop=False,
            max_wh_scale_gap=1),
        dict(type='Resize', scale=(224, 224), keep_ratio=False),
        dict(type='Flip', flip_ratio=0.5),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='FormatShape', input_format='NCHW'),
        dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]),
        dict(type='ToTensor', keys=['imgs', 'label'])
    ]
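Each dict in `train_pipeline` becomes a transform, and the transforms are applied in order to one shared `results` dict: each reads keys added earlier and adds new ones (for example `DecordInit` adds `total_frames`, which `SampleFrames` then reads). Below is a minimal sketch of that chaining with a stand-in `Compose` class, loosely modeled on the idea rather than copied from mmaction2, plus two hypothetical toy transforms.

```python
class Compose:
    """Minimal stand-in for a data pipeline: each transform takes and returns `results`."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, results):
        for transform in self.transforms:
            results = transform(results)
            if results is None:  # a transform may drop the sample
                return None
        return results

# Hypothetical toy transforms mimicking the DecordInit -> SampleFrames hand-off.
def fake_init(results):
    results['total_frames'] = 80  # DecordInit would read this from the video
    return results

def fake_sample(results):
    step = results['total_frames'] // 8
    results['frame_inds'] = [i * step for i in range(8)]
    return results

pipeline = Compose([fake_init, fake_sample])
out = pipeline({'filename': 'demo.mp4', 'label': 3, 'start_index': 0})
print(out['frame_inds'])  # [0, 10, 20, 30, 40, 50, 60, 70]
```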
    
    1. DecordInit, DecordDecode

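    DecordInit and DecordDecode handle video reading. The actual decoding is done by the decord library, which accelerates video reading in C++, so from the Python side you only see its interface (.pyi stubs) rather than the decoding code itself.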

    @PIPELINES.register_module()
    class DecordInit:
        """Using decord to initialize the video_reader.
    
        Decord: https://github.com/dmlc/decord
    
        Required keys are "filename",
        added or modified keys are "video_reader" and "total_frames".
    
        Args:
            io_backend (str): io backend where frames are store.
                Default: 'disk'.
            num_threads (int): Number of thread to decode the video. Default: 1.
            kwargs (dict): Args for file client.
        """
    
        def __init__(self, io_backend='disk', num_threads=1, **kwargs):
            self.io_backend = io_backend
            self.num_threads = num_threads
            self.kwargs = kwargs
            self.file_client = None
    
        def __call__(self, results):
            """Perform the Decord initialization.
    
            Args:
                results (dict): The resulting dict to be modified and passed
                    to the next transform in pipeline.
            """
            try:
                import decord
            except ImportError:
                raise ImportError(
                    'Please run "pip install decord" to install Decord first.')
    
            if self.file_client is None:
                self.file_client = FileClient(self.io_backend, **self.kwargs)
    
            file_obj = io.BytesIO(self.file_client.get(results['filename']))
            container = decord.VideoReader(file_obj, num_threads=self.num_threads)
            results['video_reader'] = container
            results['total_frames'] = len(container)
            return results
    
        def __repr__(self):
            repr_str = (f'{self.__class__.__name__}('
                        f'io_backend={self.io_backend}, '
                        f'num_threads={self.num_threads})')
            return repr_str
    
    @PIPELINES.register_module()
    class DecordDecode:
        """Using decord to decode the video.
    
        Decord: https://github.com/dmlc/decord
    
        Required keys are "video_reader", "filename" and "frame_inds",
        added or modified keys are "imgs" and "original_shape".
    
        Args:
            mode (str): Decoding mode. Options are 'accurate' and 'efficient'.
                If set to 'accurate', it will decode videos into accurate frames.
                If set to 'efficient', it will adopt fast seeking but only return
                key frames, which may be duplicated and inaccurate, and more
                suitable for large scene-based video datasets. Default: 'accurate'.
        """
    
        def __init__(self, mode='accurate'):
            self.mode = mode
            assert mode in ['accurate', 'efficient']
    
        def __call__(self, results):
            """Perform the Decord decoding.
    
            Args:
                results (dict): The resulting dict to be modified and passed
                    to the next transform in pipeline.
            """
            container = results['video_reader']
    
            if results['frame_inds'].ndim != 1:
                results['frame_inds'] = np.squeeze(results['frame_inds'])
    
            frame_inds = results['frame_inds']
    
            if self.mode == 'accurate':
                imgs = container.get_batch(frame_inds).asnumpy()
                imgs = list(imgs)
            elif self.mode == 'efficient':
                # This mode is faster, however it always returns I-FRAME
                container.seek(0)
                imgs = list()
                for idx in frame_inds:
                    container.seek(idx)
                    frame = container.next()
                    imgs.append(frame.asnumpy())
    
            results['video_reader'] = None
            del container
    
            results['imgs'] = imgs
            results['original_shape'] = imgs[0].shape[:2]
            results['img_shape'] = imgs[0].shape[:2]
    
            return results
    
        def __repr__(self):
            repr_str = f'{self.__class__.__name__}(mode={self.mode})'
            return repr_str
    
    2. SampleFrames

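    We use SampleFrames with dict(type='SampleFrames', clip_len=1, frame_interval=1, num_clips=8). During training this splits a video into 8 equal segments and randomly samples one frame from each segment, giving 8 frames in total. In test mode (test_mode=True), the center frame of each segment is taken instead.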

    @PIPELINES.register_module()
    class SampleFrames:
        """Sample frames from the video.
    
        Required keys are "total_frames", "start_index" , added or modified keys
        are "frame_inds", "frame_interval" and "num_clips".
    
        Args:
            clip_len (int): Frames of each sampled output clip.
            frame_interval (int): Temporal interval of adjacent sampled frames.
                Default: 1.
            num_clips (int): Number of clips to be sampled. Default: 1.
            temporal_jitter (bool): Whether to apply temporal jittering.
                Default: False.
            twice_sample (bool): Whether to use twice sample when testing.
                If set to True, it will sample frames with and without fixed shift,
                which is commonly used for testing in TSM model. Default: False.
            out_of_bound_opt (str): The way to deal with out of bounds frame
                indexes. Available options are 'loop', 'repeat_last'.
                Default: 'loop'.
            test_mode (bool): Store True when building test or validation dataset.
                Default: False.
            start_index (None): This argument is deprecated and moved to dataset
                class (``BaseDataset``, ``VideoDatset``, ``RawframeDataset``, etc),
                see this: https://github.com/open-mmlab/mmaction2/pull/89.
            keep_tail_frames (bool): Whether to keep tail frames when sampling.
                Default: False.
        """
    
        def __init__(self,
                     clip_len,
                     frame_interval=1,
                     num_clips=1,
                     temporal_jitter=False,
                     twice_sample=False,
                     out_of_bound_opt='loop',
                     test_mode=False,
                     start_index=None,
                     keep_tail_frames=False):
    
            self.clip_len = clip_len
            self.frame_interval = frame_interval
            self.num_clips = num_clips
            self.temporal_jitter = temporal_jitter
            self.twice_sample = twice_sample
            self.out_of_bound_opt = out_of_bound_opt
            self.test_mode = test_mode
            self.keep_tail_frames = keep_tail_frames
            assert self.out_of_bound_opt in ['loop', 'repeat_last']
    
            if start_index is not None:
                warnings.warn('No longer support "start_index" in "SampleFrames", '
                              'it should be set in dataset class, see this pr: '
                              'https://github.com/open-mmlab/mmaction2/pull/89')
    
        def _get_train_clips(self, num_frames):
            """Get clip offsets in train mode.
    
            It will calculate the average interval for selected frames,
            and randomly shift them within offsets between [0, avg_interval].
            If the total number of frames is smaller than clips num or origin
            frames length, it will return all zero indices.
    
            Args:
                num_frames (int): Total number of frame in the video.
    
            Returns:
                np.ndarray: Sampled frame indices in train mode.
            """
            ori_clip_len = self.clip_len * self.frame_interval
    
            if self.keep_tail_frames:
                avg_interval = (num_frames - ori_clip_len + 1) / float(
                    self.num_clips)
                if num_frames > ori_clip_len - 1:
                    base_offsets = np.arange(self.num_clips) * avg_interval
                    clip_offsets = (base_offsets + np.random.uniform(
                        0, avg_interval, self.num_clips)).astype(int)
                else:
                    clip_offsets = np.zeros((self.num_clips, ), dtype=int)
            else:
                avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips
    
                if avg_interval > 0:
                    base_offsets = np.arange(self.num_clips) * avg_interval
                    clip_offsets = base_offsets + np.random.randint(
                        avg_interval, size=self.num_clips)
                elif num_frames > max(self.num_clips, ori_clip_len):
                    clip_offsets = np.sort(
                        np.random.randint(
                            num_frames - ori_clip_len + 1, size=self.num_clips))
                elif avg_interval == 0:
                    ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
                    clip_offsets = np.around(np.arange(self.num_clips) * ratio)
                else:
                    clip_offsets = np.zeros((self.num_clips, ), dtype=int)
    
            return clip_offsets
    
        def _get_test_clips(self, num_frames):
            """Get clip offsets in test mode.
    
            Calculate the average interval for selected frames, and shift them
            fixedly by avg_interval/2. If set twice_sample True, it will sample
            frames together without fixed shift. If the total number of frames is
            not enough, it will return all zero indices.
    
            Args:
                num_frames (int): Total number of frame in the video.
    
            Returns:
                np.ndarray: Sampled frame indices in test mode.
            """
            ori_clip_len = self.clip_len * self.frame_interval
            avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)
            if num_frames > ori_clip_len - 1:
                base_offsets = np.arange(self.num_clips) * avg_interval
                clip_offsets = (base_offsets + avg_interval / 2.0).astype(int)
                if self.twice_sample:
                    clip_offsets = np.concatenate([clip_offsets, base_offsets])
            else:
                clip_offsets = np.zeros((self.num_clips, ), dtype=int)
            return clip_offsets
    
        def _sample_clips(self, num_frames):
            """Choose clip offsets for the video in a given mode.
    
            Args:
                num_frames (int): Total number of frame in the video.
    
            Returns:
                np.ndarray: Sampled frame indices.
            """
            if self.test_mode:
                clip_offsets = self._get_test_clips(num_frames)
            else:
                clip_offsets = self._get_train_clips(num_frames)
    
            return clip_offsets
    
        def __call__(self, results):
            """Perform the SampleFrames loading.
    
            Args:
                results (dict): The resulting dict to be modified and passed
                    to the next transform in pipeline.
            """
            total_frames = results['total_frames']
    
            clip_offsets = self._sample_clips(total_frames)
            frame_inds = clip_offsets[:, None] + np.arange(
                self.clip_len)[None, :] * self.frame_interval
            frame_inds = np.concatenate(frame_inds)
    
            if self.temporal_jitter:
                perframe_offsets = np.random.randint(
                    self.frame_interval, size=len(frame_inds))
                frame_inds += perframe_offsets
    
            frame_inds = frame_inds.reshape((-1, self.clip_len))
            if self.out_of_bound_opt == 'loop':
                frame_inds = np.mod(frame_inds, total_frames)
            elif self.out_of_bound_opt == 'repeat_last':
                safe_inds = frame_inds < total_frames
                unsafe_inds = 1 - safe_inds
                last_ind = np.max(safe_inds * frame_inds, axis=1)
                new_inds = (safe_inds * frame_inds + (unsafe_inds.T * last_ind).T)
                frame_inds = new_inds
            else:
                raise ValueError('Illegal out_of_bound option.')
    
            start_index = results['start_index']
            frame_inds = np.concatenate(frame_inds) + start_index
            results['frame_inds'] = frame_inds.astype(int)
            results['clip_len'] = self.clip_len
            results['frame_interval'] = self.frame_interval
            results['num_clips'] = self.num_clips
            return results
    
        def __repr__(self):
            repr_str = (f'{self.__class__.__name__}('
                        f'clip_len={self.clip_len}, '
                        f'frame_interval={self.frame_interval}, '
                        f'num_clips={self.num_clips}, '
                        f'temporal_jitter={self.temporal_jitter}, '
                        f'twice_sample={self.twice_sample}, '
                        f'out_of_bound_opt={self.out_of_bound_opt}, '
                        f'test_mode={self.test_mode})')
            return repr_str
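For the settings used here (`clip_len=1`, `frame_interval=1`), the sampling logic above reduces to a few lines. This is a simplified sketch of mine (hypothetical name `sample_segment_indices`, `start_index` assumed to be 0, no `twice_sample` or out-of-bound handling), not a drop-in replacement for `SampleFrames`.

```python
import numpy as np

def sample_segment_indices(num_frames, num_clips=8, test_mode=False, rng=None):
    """Simplified SampleFrames for clip_len=1, frame_interval=1 (start_index=0)."""
    if test_mode:
        # Test: take the centre frame of each of the num_clips equal segments.
        avg = num_frames / num_clips
        return (np.arange(num_clips) * avg + avg / 2.0).astype(int)
    avg = num_frames // num_clips
    if avg > 0:
        # Train: one random frame inside each segment [i*avg, (i+1)*avg).
        rng = np.random.default_rng() if rng is None else rng
        return np.arange(num_clips) * avg + rng.integers(avg, size=num_clips)
    # Fewer frames than segments: spread indices evenly, so some frames repeat.
    ratio = num_frames / num_clips
    return np.around(np.arange(num_clips) * ratio).astype(int)

print(sample_segment_indices(80, test_mode=True).tolist())  # [5, 15, 25, 35, 45, 55, 65, 75]
```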
    
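    3. Resize, MultiScaleCrop, Flip, Normalize, FormatShape, Collect, ToTensor

    These transforms are straightforward, so I won't go into detail. Worth pointing out: Collect gathers the frames and the label into the keys passed to the model. After batching, the data looks like this:
    imgs: shape [8, 8, 3, 224, 224], i.e. batch size, number of frames, channels, frame height, frame width
    labels: shape [8, 1], i.e. batch size and one class index per video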
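The last steps can be sketched in NumPy: normalize each RGB frame with the `img_norm_cfg` values (`to_bgr=False`), move channels in front of height and width as `input_format='NCHW'` does (the frames play the role of N), and stack samples into a batch as the data loader would. This is an illustration with my own helper name, not the mmaction2 transforms, and it uses 32×32 frames instead of 224×224 to keep the demo light.

```python
import numpy as np

MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)  # img_norm_cfg mean (RGB)
STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)     # img_norm_cfg std (RGB)

def normalize_and_format(frames):
    """frames: list of T uint8 RGB images (H, W, 3) -> float32 array (T, 3, H, W)."""
    imgs = np.stack(frames).astype(np.float32)  # (T, H, W, 3)
    imgs = (imgs - MEAN) / STD                  # per-channel normalization on the last axis
    return imgs.transpose(0, 3, 1, 2)           # 'NCHW': frames play the role of N

# 8 videos x 8 sampled frames each.
videos = [[np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(8)] for _ in range(8)]
batch = np.stack([normalize_and_format(v) for v in videos])
print(batch.shape)  # (8, 8, 3, 32, 32)
```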

    2.2 Model Structure

    2.2.1 Pipeline

    The main pipeline code is as follows:

    @RECOGNIZERS.register_module()
    class Recognizer2D(BaseRecognizer):
        """2D recognizer model framework."""
    
        def forward_train(self, imgs, labels, **kwargs):
            """Defines the computation performed at every call when training."""
    
            assert self.with_cls_head
            batches = imgs.shape[0]
            imgs = imgs.reshape((-1, ) + imgs.shape[2:])
            num_segs = imgs.shape[0] // batches
    
            losses = dict()
    
            x = self.extract_feat(imgs)
    
            if self.backbone_from in ['torchvision', 'timm']:
                if len(x.shape) == 4 and (x.shape[2] > 1 or x.shape[3] > 1):
                    # apply adaptive avg pooling
                    x = nn.AdaptiveAvgPool2d(1)(x)
                x = x.reshape((x.shape[0], -1))
                x = x.reshape(x.shape + (1, 1))
    
            if self.with_neck:
                x = [
                    each.reshape((-1, num_segs) +
                                 each.shape[1:]).transpose(1, 2).contiguous()
                    for each in x
                ]
                x, loss_aux = self.neck(x, labels.squeeze())
                x = x.squeeze(2)
                num_segs = 1
                losses.update(loss_aux)
    
            cls_score = self.cls_head(x, num_segs) # x shape [64, 2048, 7, 7]
            gt_labels = labels.squeeze()
            loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
            losses.update(loss_cls)
    
            return losses
    

    batches: 8, the batch size.
    imgs = imgs.reshape((-1, ) + imgs.shape[2:]): imgs starts as (8, 8, 3, 224, 224); merging the batch and frame dimensions gives (64, 3, 224, 224).
    num_segs: the number of frames per video, 64 // 8 = 8.
    extract_feat: runs the backbone to extract features.
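The reshape at the start of `forward_train` can be reproduced with NumPy arrays standing in for torch tensors. After merging, the 2D backbone treats every frame as an independent image, and the frames of one video sit next to each other, which is what lets the temporal shift later split the first dimension back into (videos, frames).

```python
import numpy as np

imgs = np.zeros((8, 8, 3, 224, 224), dtype=np.float32)  # (batch, frames, C, H, W)
batches = imgs.shape[0]                                  # 8
imgs = imgs.reshape((-1, ) + imgs.shape[2:])             # merge batch and frame dims
num_segs = imgs.shape[0] // batches                      # frames per video

print(imgs.shape, num_segs)  # (64, 3, 224, 224) 8
```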

    2.2.2 Feature Extraction

    Next, the data enters the feature extraction module (the backbone's forward):

        def forward(self, x):
            """Defines the computation performed at every call.
    
            Args:
                x (torch.Tensor): The input data.
    
            Returns:
                torch.Tensor: The feature of the input samples extracted
                by the backbone.
            """
            # x shape [64, 3, 224, 224]
            x = self.conv1(x)
            # x shape [64, 64, 112, 112]
            x = self.maxpool(x)
            # x shape [64, 64, 56, 56]
            outs = []
            for i, layer_name in enumerate(self.res_layers):
                res_layer = getattr(self, layer_name)
                x = res_layer(x)
                if i in self.out_indices:
                    outs.append(x)
            if len(outs) == 1:
                return outs[0]
    
            return tuple(outs)
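The shape comments can be checked with the standard output-size formula. The sketch assumes the usual ResNet-50 stem (7×7 conv, stride 2, padding 3, then 3×3 max pooling, stride 2, padding 1) and, as in the printed layers, puts each stage's stride on `conv2` of its first bottleneck. The temporal shift does not change any of these shapes.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet50_shapes(n=64, size=224):
    shapes = {}
    size = conv_out(size, 7, 2, 3)           # conv1: 7x7, stride 2, padding 3
    shapes['conv1'] = (n, 64, size, size)
    size = conv_out(size, 3, 2, 1)           # maxpool: 3x3, stride 2, padding 1
    shapes['maxpool'] = (n, 64, size, size)
    stages = [('layer1', 256, 1), ('layer2', 512, 2), ('layer3', 1024, 2), ('layer4', 2048, 2)]
    for name, channels, stride in stages:
        size = conv_out(size, 3, stride, 1)  # stride sits on conv2 (3x3) of the first bottleneck
        shapes[name] = (n, channels, size, size)
    return shapes

for name, shape in resnet50_shapes().items():
    print(name, shape)
```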
    

    The input first goes through conv1 and maxpool, then through the four res_layers. The following code shows how those four res_layers are built:

        def make_temporal_shift(self):
            """Make temporal shift for some layers."""
            if self.temporal_pool:
                num_segment_list = [
                    self.num_segments, self.num_segments // 2,
                    self.num_segments // 2, self.num_segments // 2
                ]
            else:
                num_segment_list = [self.num_segments] * 4
            if num_segment_list[-1] <= 0:
                raise ValueError('num_segment_list[-1] must be positive')
    
            if self.shift_place == 'block':
    
                def make_block_temporal(stage, num_segments):
                    """Make temporal shift on some blocks.
    
                    Args:
                        stage (nn.Module): Model layers to be shifted.
                        num_segments (int): Number of frame segments.
    
                    Returns:
                        nn.Module: The shifted blocks.
                    """
                    blocks = list(stage.children())
                    for i, b in enumerate(blocks):
                        blocks[i] = TemporalShift(
                            b, num_segments=num_segments, shift_div=self.shift_div)
                    return nn.Sequential(*blocks)
    
                self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
                self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
                self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
                self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
    
            elif 'blockres' in self.shift_place:
                n_round = 1
                if len(list(self.layer3.children())) >= 23:
                    n_round = 2
    
                def make_block_temporal(stage, num_segments):
                    """Make temporal shift on some blocks.
    
                    Args:
                        stage (nn.Module): Model layers to be shifted.
                        num_segments (int): Number of frame segments.
    
                    Returns:
                        nn.Module: The shifted blocks.
                    """
                    blocks = list(stage.children())
                    for i, b in enumerate(blocks):
                        if i % n_round == 0:
                            blocks[i].conv1.conv = TemporalShift(
                                b.conv1.conv,
                                num_segments=num_segments,
                                shift_div=self.shift_div)
                    return nn.Sequential(*blocks)
    
                self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
                self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
                self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
                self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
    
            else:
                raise NotImplementedError
    

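    With the default temporal_pool=False, make_temporal_shift produces num_segment_list = [8, 8, 8, 8], and we use shift_place='blockres'. In this mode, TemporalShift wraps conv1.conv of every n_round-th bottleneck, so only the residual branch is shifted. n_round is 2 only when layer3 has at least 23 blocks (e.g. ResNet-101); for ResNet-50 it is 1, so every block is shifted.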
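The `n_round` rule in `make_temporal_shift` can be mirrored in plain Python to list which bottlenecks get a shift. The helper name is hypothetical; the block counts per stage are the standard ResNet-50 (3, 4, 6, 3) and ResNet-101 (3, 4, 23, 3) configurations.

```python
def shifted_block_ids(stage_blocks):
    """Indices of bottlenecks whose conv1.conv gets a TemporalShift in 'blockres' mode."""
    # Same rule as make_temporal_shift: if layer3 has >= 23 blocks, shift every other block.
    n_round = 2 if stage_blocks[2] >= 23 else 1
    return [[i for i in range(n) if i % n_round == 0] for n in stage_blocks]

print(shifted_block_ids([3, 4, 6, 3]))   # ResNet-50
print(shifted_block_ids([3, 4, 23, 3]))  # ResNet-101
```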

    1. Building self.layer1, self.layer2, self.layer3, self.layer4
      These layers are rebuilt by make_block_temporal. Before the shift is inserted, self.layer1 to self.layer4 consist of standard ResNet bottleneck residual blocks; their internal structure is printed below. We will then look at self.layer1 to self.layer4 again after TemporalShift has been inserted.
      layer1:
    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer2:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer3:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer4:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    
    1. TemporalShift

    The core logic of TemporalShift is shown below:

    import torch
    import torch.nn as nn


    class TemporalShift(nn.Module):
        """Temporal shift module.
    
        This module is proposed in
        `TSM: Temporal Shift Module for Efficient Video Understanding
        <https://arxiv.org/abs/1811.08383>`_
    
        Args:
            net (nn.module): Module to make temporal shift.
            num_segments (int): Number of frame segments. Default: 3.
            shift_div (int): Number of divisions for shift. Default: 8.
        """
    
        def __init__(self, net, num_segments=3, shift_div=8):
            super().__init__()
            self.net = net
            self.num_segments = num_segments
            self.shift_div = shift_div
    
        def forward(self, x):
            """Defines the computation performed at every call.
    
            Args:
                x (torch.Tensor): The input data.
    
            Returns:
                torch.Tensor: The output of the module.
            """
            # self.num_segments = 8
            x = self.shift(x, self.num_segments, shift_div=self.shift_div)
            return self.net(x)
    
        @staticmethod
        def shift(x, num_segments, shift_div=3):
            """Perform temporal shift operation on the feature.
    
            Args:
                x (torch.Tensor): The input feature to be shifted.
                num_segments (int): Number of frame segments.
                shift_div (int): Number of divisions for shift. Default: 3.
    
            Returns:
                torch.Tensor: The shifted feature.
            """
            # [N, C, H, W]
            n, c, h, w = x.size()
            # n=64, h=56, w=56, c=64
    
            # [N // num_segments, num_segments, C, H*W]
            # can't use 5 dimensional array on PPL2D backend for caffe
            x = x.view(-1, num_segments, c, h * w) # x shape [8, 8, 64, 3136]
    
            # get shift fold
            fold = c // shift_div
    
            # split c channel into three parts:
            # left_split, mid_split, right_split
            left_split = x[:, :, :fold, :] # shape[8, 8, 8, 3136]
            mid_split = x[:, :, fold:2 * fold, :] # shape[8, 8, 8, 3136]
            right_split = x[:, :, 2 * fold:, :] # shape [8, 8, 48, 3136]
    
            # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
            # because array on caffe inference must be got by computing
    
            # shape [N, num_segments (time), C, H*W]
    
            # shift left on num_segments channel in `left_split`
            zeros = left_split - left_split
            blank = zeros[:, :1, :, :] # shape [8, 1, 8, 3136]
            left_split = left_split[:, 1:, :, :] # shape [8, 7, 8, 3136]
            left_split = torch.cat((left_split, blank), 1)
    
            # shift right on num_segments channel in `mid_split`
            zeros = mid_split - mid_split
            blank = zeros[:, :1, :, :]
            mid_split = mid_split[:, :-1, :, :]
            mid_split = torch.cat((blank, mid_split), 1)
    
            # right_split: no shift
    
            # concatenate
            out = torch.cat((left_split, mid_split, right_split), 2)
    
            # [N, C, H, W]
            # restore the original dimension
            return out.view(n, c, h, w)
    

    The code above is the core of the whole model.

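    x = x.view(-1, num_segments, c, h * w) merges the height and width axes into one and splits the batch axis into (videos, num_segments). The result has shape [8, 8, 64, 3136] (3136 = 56 * 56).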
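Below is a small check of what this reshape does. It is a sketch in NumPy rather than PyTorch, chosen only so that it runs without PyTorch. For a contiguous array, reshape produces the same layout as view. The sizes are made up and much smaller than the [64, 64, 56, 56] feature in the comments:

```python
import numpy as np

# Made-up small sizes: N videos with T = num_segments frames each.
N, T, C, H, W = 2, 4, 8, 3, 3

# What the backbone sees: all frames stacked on the batch axis, [N * T, C, H, W],
# with the T frames of each video stored next to each other.
x = np.random.rand(N * T, C, H, W)

# Same reshape as x.view(-1, num_segments, c, h * w) in TemporalShift.shift.
x_split = x.reshape(-1, T, C, H * W)
print(x_split.shape)  # (2, 4, 8, 9)

# x_split[b, t] is frame t of video b, i.e. row b * T + t of the stacked batch.
b, t = 1, 2
assert np.array_equal(x_split[b, t], x[b * T + t].reshape(C, H * W))
```

This layout is also why the shift only works if the frames of each video sit consecutively along the batch axis. If frames of different videos were interleaved, shifting along axis 1 would mix videos.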

    fold = c // shift_div
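    Here c is the number of channels (64) and shift_div is 8, so fold = 8. This is the split described earlier: the 64 channels are divided into 8 groups of 8 channels each.
    - One group (left_split) is shifted left along the time axis, so frame t receives the features of frame t+1.
    - One group (mid_split) is shifted right, so frame t receives the features of frame t-1.
    - The remaining 6 groups (right_split, 48 channels) stay unchanged.
    Finally, the three parts are concatenated back along the channel axis.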

    left_split = x[:, :, :fold, :] # shape[8, 8, 8, 3136]
    mid_split = x[:, :, fold:2 * fold, :] # shape[8, 8, 8, 3136]
    right_split = x[:, :, 2 * fold:, :] # shape [8, 8, 48, 3136]
    
    # shift left on num_segments channel in `left_split`
    zeros = left_split - left_split
    blank = zeros[:, :1, :, :] # shape [8, 1, 8, 3136]
    left_split = left_split[:, 1:, :, :] # shape [8, 7, 8, 3136]
    left_split = torch.cat((left_split, blank), 1)
    
    # shift right on num_segments channel in `mid_split`
    zeros = mid_split - mid_split
    blank = zeros[:, :1, :, :]
    mid_split = mid_split[:, :-1, :, :]
    mid_split = torch.cat((blank, mid_split), 1)
    # right_split: no shift
    # concatenate
    out = torch.cat((left_split, mid_split, right_split), 2)
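To see the effect of the three splits, here is a NumPy transcription of shift. It is a sketch for illustration, not mmaction2 code, and temporal_shift is a hypothetical name. It uses np.zeros_like for the padding frame; according to its comments, the original avoids that only because of the Caffe-export constraint. The test input is one video of 4 frames in which every channel of frame t holds the value t + 1, so the zero padding stands out:

```python
import numpy as np


def temporal_shift(x, num_segments, shift_div=8):
    """NumPy transcription of TemporalShift.shift, for illustration only."""
    n, c, h, w = x.shape
    x = x.reshape(-1, num_segments, c, h * w)   # [N, T, C, H*W]
    fold = c // shift_div

    left = x[:, :, :fold]                       # left_split
    mid = x[:, :, fold:2 * fold]                # mid_split
    right = x[:, :, 2 * fold:]                  # right_split, not shifted

    blank = np.zeros_like(left[:, :1])          # one all-zero frame, [N, 1, fold, H*W]
    left = np.concatenate((left[:, 1:], blank), axis=1)   # frame t <- frame t+1
    mid = np.concatenate((blank, mid[:, :-1]), axis=1)    # frame t <- frame t-1

    out = np.concatenate((left, mid, right), axis=2)
    return out.reshape(n, c, h, w)


# One video, T = 4 frames, 16 channels, 1x1 spatial; every channel of frame t holds t + 1.
T, C = 4, 16
x = np.repeat(np.arange(1, T + 1, dtype=float), C).reshape(T, C, 1, 1)
y = temporal_shift(x, num_segments=T)          # fold = 16 // 8 = 2

print(y[:, 0, 0, 0])   # a left_split channel:  [2. 3. 4. 0.]
print(y[:, 2, 0, 0])   # a mid_split channel:   [0. 1. 2. 3.]
print(y[:, 8, 0, 0])   # a right_split channel: [1. 2. 3. 4.]
```

The last frame of left_split and the first frame of mid_split are filled with zeros. These are the blank regions in the shift figure.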
    
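    2. TemporalShift + residual conv

    With the shift inserted, let's look at layer1, layer2, layer3 and layer4 again. In every Bottleneck, conv1 (the first 1x1 conv on the residual branch) is now wrapped in TemporalShift. conv2, conv3 and the downsample path are unchanged:

    layer1: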
    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer2:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer3:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
    

    layer4:

    Sequential(
      (0): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
        (downsample): ConvModule(
          (conv): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): ConvModule(
          (conv): TemporalShift(
            (net): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv2): ConvModule(
          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activate): ReLU(inplace=True)
        )
        (conv3): ConvModule(
          (conv): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (relu): ReLU(inplace=True)
      )
    )
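The shift is applied to the input of conv1 in every Bottleneck, so the number of shifted channels changes from stage to stage. The tally below is pure Python. Channel counts are read off the printouts above, and shift_div = 8 is the TemporalShift default:

```python
# Input channels of the shifted conv1 in each Bottleneck, read off the printouts:
# (stage, channels in block 0, channels in later blocks, number of blocks)
stages = [
    ("layer1", 64, 256, 3),
    ("layer2", 256, 512, 4),
    ("layer3", 512, 1024, 6),
    ("layer4", 1024, 2048, 3),
]
shift_div = 8

rows = []
for name, first_c, later_c, num_blocks in stages:
    for i in range(num_blocks):
        c = first_c if i == 0 else later_c
        fold = c // shift_div                 # channels moved in each direction
        rows.append((name, i, c, 2 * fold, c - 2 * fold))
        print(f"{name}[{i}]: C={c:4d}  shifted={2 * fold:4d}  unchanged={c - 2 * fold:4d}")
```

In every block, a quarter of the conv1 input channels move in time (1/8 in each direction), and the other three quarters stay in place.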
    
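    3. Output
      The final output feature has shape [64, 2048, 7, 7], i.e. [N * num_segments, 2048, 7, 7] with N = 8 videos and num_segments = 8.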
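The shape can be checked with a little arithmetic. The sketch below makes two assumptions that this section does not show: a 224x224 input crop, and the standard ResNet stem (a stride-2 7x7 conv followed by stride-2 max pooling). Together they give the 56x56 size noted in the TemporalShift comments. backbone_output_shape is a hypothetical helper:

```python
def backbone_output_shape(num_videos=8, num_segments=8, input_size=224):
    """Shape arithmetic only; input_size=224 is an assumption, not read from the config."""
    batch = num_videos * num_segments   # frames of all videos stacked on the batch axis
    size = input_size // 4              # assumed stem: stride-2 conv + stride-2 max-pool -> 56
    for _ in range(3):                  # layer2, layer3, layer4
        size //= 2                      # stride-2 conv2 in block 0 of each of these stages
    return (batch, 2048, size, size)    # layer4 outputs 2048 channels


print(backbone_output_shape())  # (64, 2048, 7, 7)
```

The integer division assumes the input size is divisible by 32, which holds for 224.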
    2.2.3 head

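    The head is simple. It is essentially an fc layer that maps the 2048-dimensional feature to the number of classes, and its output is then used to compute the loss. Note, however, that the feature fed into the head has shape [N * num_segs, in_channels, 7, 7] (N = 8, num_segs = 8).
    The head therefore works in four steps:
    - Global average pooling over the 7x7 spatial map gives [N * num_segs, in_channels, 1, 1].
    - Flattening gives [64, 2048].
    - The fc layer gives [64, num_classes].
    - The batch axis is split back into [N, num_segs, num_classes], and the scores are averaged over the segments (AvgConsensus), so the final shape is [N, num_classes].
    The code is as follows.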

    
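    # Imports as in mmaction/models/heads/tsm_head.py (mmaction2 0.x)
    import torch
    import torch.nn as nn
    from mmcv.cnn import normal_init

    from ..builder import HEADS
    from .base import AvgConsensus, BaseHead


    @HEADS.register_module()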
    class TSMHead(BaseHead):
        """Class head for TSM.
    
        Args:
            num_classes (int): Number of classes to be classified.
            in_channels (int): Number of channels in input feature.
            num_segments (int): Number of frame segments. Default: 8.
            loss_cls (dict): Config for building loss.
                Default: dict(type='CrossEntropyLoss')
            spatial_type (str): Pooling type in spatial dimension. Default: 'avg'.
            consensus (dict): Consensus config dict.
            dropout_ratio (float): Probability of dropout layer. Default: 0.8.
            init_std (float): Std value for initialization. Default: 0.001.
            is_shift (bool): Indicating whether the feature is shifted.
                Default: True.
            temporal_pool (bool): Indicating whether feature is temporal pooled.
                Default: False.
            kwargs (dict, optional): Any keyword argument to be used to initialize
                the head.
        """
    
        def __init__(self,
                     num_classes,
                     in_channels,
                     num_segments=8,
                     loss_cls=dict(type='CrossEntropyLoss'),
                     spatial_type='avg',
                     consensus=dict(type='AvgConsensus', dim=1),
                     dropout_ratio=0.8,
                     init_std=0.001,
                     is_shift=True,
                     temporal_pool=False,
                     **kwargs):
            super().__init__(num_classes, in_channels, loss_cls, **kwargs)
    
            self.spatial_type = spatial_type
            self.dropout_ratio = dropout_ratio
            self.num_segments = num_segments
            self.init_std = init_std
            self.is_shift = is_shift
            self.temporal_pool = temporal_pool
    
            consensus_ = consensus.copy()
    
            consensus_type = consensus_.pop('type')
            if consensus_type == 'AvgConsensus':
                self.consensus = AvgConsensus(**consensus_)
            else:
                self.consensus = None
    
            if self.dropout_ratio != 0:
                self.dropout = nn.Dropout(p=self.dropout_ratio)
            else:
                self.dropout = None
            self.fc_cls = nn.Linear(self.in_channels, self.num_classes)
    
            if self.spatial_type == 'avg':
                # use `nn.AdaptiveAvgPool2d` to adaptively match the in_channels.
                self.avg_pool = nn.AdaptiveAvgPool2d(1)
            else:
                self.avg_pool = None
    
        def init_weights(self):
            """Initiate the parameters from scratch."""
            normal_init(self.fc_cls, std=self.init_std)
    
        def forward(self, x, num_segs):
            """Defines the computation performed at every call.
    
            Args:
                x (torch.Tensor): The input data.
                num_segs (int): Useless in TSMHead. By default, `num_segs`
                    is equal to `clip_len * num_clips * num_crops`, which is
                    automatically generated in Recognizer forward phase and
                    useless in TSM models. The `self.num_segments` we need is a
                    hyper parameter to build TSM models.
            Returns:
                torch.Tensor: The classification scores for input samples.
            """
            # [N * num_segs, in_channels, 7, 7]
            if self.avg_pool is not None: # x shape [64, 2048, 7, 7]
                x = self.avg_pool(x)  # global average pooling: spatial 1x1, channels unchanged
                                      # x shape [64, 2048, 1, 1]
            # [N * num_segs, in_channels, 1, 1]
            x = torch.flatten(x, 1) # x shape [64, 2048]
            # [N * num_segs, in_channels]
            if self.dropout is not None:
                x = self.dropout(x)
            # [N * num_segs, num_classes]
            cls_score = self.fc_cls(x) # cls shape [64, 400]
    
            if self.is_shift and self.temporal_pool:
                # [2 * N, num_segs // 2, num_classes]
                cls_score = cls_score.view((-1, self.num_segments // 2) +
                                           cls_score.size()[1:])
            else:
                # [N, num_segs, num_classes]
                cls_score = cls_score.view((-1, self.num_segments) +
                                           cls_score.size()[1:])  # split the batch back into videos
            # cls shape [8, 8, 400]
            # [N, 1, num_classes]
            cls_score = self.consensus(cls_score)
            # [N, num_classes]
            return cls_score.squeeze(1)
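The shape flow of TSMHead.forward can be reproduced with a NumPy sketch. It runs in eval mode, so dropout is skipped, and uses is_shift=True and temporal_pool=False as in the defaults. The random features and fc weights are placeholders, the 400 classes mirror the shape comments in the code above, and tsm_head_forward is a hypothetical name:

```python
import numpy as np


def tsm_head_forward(x, weight, bias, num_segments):
    """NumPy sketch of TSMHead.forward at eval time (no dropout, temporal_pool=False)."""
    x = x.mean(axis=(2, 3))                     # AdaptiveAvgPool2d(1) + flatten: [N*T, C]
    cls_score = x @ weight.T + bias             # fc_cls: [N*T, num_classes]
    cls_score = cls_score.reshape(-1, num_segments, cls_score.shape[-1])  # [N, T, num_classes]
    return cls_score.mean(axis=1)               # AvgConsensus(dim=1) + squeeze(1): [N, num_classes]


rng = np.random.default_rng(0)
N, T, C, num_classes = 8, 8, 2048, 400
feat = rng.standard_normal((N * T, C, 7, 7)).astype(np.float32)
weight = rng.standard_normal((num_classes, C)).astype(np.float32) * 0.001
bias = np.zeros(num_classes, dtype=np.float32)

scores = tsm_head_forward(feat, weight, bias, num_segments=T)
print(scores.shape)  # (8, 400)
```

Because fc_cls is linear, averaging the per-segment scores (AvgConsensus) gives the same result as averaging the per-segment features first and then applying the fc layer.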
    

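    2.3 Loss function

    The loss used here is the cross-entropy loss (CrossEntropyLoss, the default loss_cls of TSMHead). It is computed on the [N, num_classes] scores returned by the head.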
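For reference, here is a NumPy sketch of mean cross-entropy on the head's [N, num_classes] output. For integer labels it should agree with torch.nn.functional.cross_entropy under its default mean reduction. The numbers are toy values:

```python
import numpy as np


def cross_entropy(cls_score, labels):
    """Mean cross-entropy over the batch; cls_score holds [N, num_classes] logits."""
    shifted = cls_score - cls_score.max(axis=1, keepdims=True)   # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# Two clips, three classes.
scores = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
loss = cross_entropy(scores, labels)
print(loss)
```

The second clip has uniform scores, so its loss term is log(3) regardless of its label.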

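    Takeaways:

    This model shows that deeper convolutions carry information from more neighboring frames (action-level semantic information). Shallower ones cover fewer frames and mostly hold spatial information. Come to think of it, the design is rather neat.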
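The intuition can be made more concrete by counting. Each Bottleneck shifts part of its conv1 input by one frame in each direction, so after k blocks a frame can in principle receive information from up to k frames away. The pure-Python count below uses the block numbers from the printouts (3, 4, 6, 3) and num_segments = 8. It gives a reachability upper bound, not a measured effective receptive field:

```python
# Bottleneck blocks per stage, counted from the printouts above (ResNet-50).
blocks_per_stage = [("layer1", 3), ("layer2", 4), ("layer3", 6), ("layer4", 3)]
num_segments = 8

reach = 0          # max temporal distance (frames, each direction) information can travel
history = []
for stage, num_blocks in blocks_per_stage:
    reach += num_blocks                           # each shifted block moves info by one frame
    window = min(2 * reach + 1, num_segments)     # upper bound, clipped to the clip length
    history.append((stage, reach, window))
    print(f"after {stage}: reach +/-{reach} frames, window <= {window} of {num_segments}")
```

By this count, every frame can already draw on the whole 8-frame clip after layer2. This fits the observation that deeper layers carry more cross-frame information.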


    Article link: https://www.haomeiwen.com/subject/vyjuprtx.html