美文网首页
torchvision.models.segmentation.

torchvision.models.segmentation.

作者: blair_liu | 来源:发表于2021-03-11 20:35 被阅读0次

    在任意 Python 文件中写下下面这行导入语句:

    from torchvision.models.segmentation.segmentation import fcn_resnet50
    

    然后跳转到 fcn_resnet50 的定义:

    def fcn_resnet50(pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs):
        """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.

        Thin wrapper that forwards every argument to
        ``_load_model('fcn', 'resnet50', ...)``.

        Args:
            pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
                contains the same classes as Pascal VOC
            progress (bool): If True, displays a progress bar of the download to stderr
            num_classes (int): Number of output classes (default 21).
            aux_loss (bool or None): Whether the model should carry an auxiliary
                classifier; passed through to ``_load_model``.
            **kwargs: Extra keyword arguments, passed through unchanged.

        Returns:
            The model object returned by ``_load_model``.
        """
        return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs)
    

    _load_model加载模型

    def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
        """Build a segmentation model and optionally load its pretrained weights.

        Args:
            arch_type: Architecture name, e.g. ``'fcn'``.
            backbone: Backbone name, e.g. ``'resnet50'``.
            pretrained: If truthy, download and load the checkpoint registered in
                ``model_urls`` under ``'<arch_type>_<backbone>_coco'``.
            progress: Forwarded to ``load_state_dict_from_url`` (progress bar switch).
            num_classes: Number of output classes.
            aux_loss: Whether to build the auxiliary classifier. Forced to True
                when ``pretrained`` is set.
            **kwargs: Extra keyword arguments forwarded to ``_segm_resnet``.

        Returns:
            The model built by ``_segm_resnet``, with pretrained weights loaded
            when requested.

        Raises:
            NotImplementedError: If ``pretrained`` is set but no checkpoint URL is
                registered for the requested architecture.
        """
        model_url = None
        if pretrained:
            arch = arch_type + '_' + backbone + '_coco'
            # Resolve the URL before building anything so an unsupported
            # combination fails fast. ``.get`` makes an unregistered key take
            # the same NotImplementedError path as a key registered as None.
            # NOTE(review): assumes model_urls is a plain dict - confirm.
            model_url = model_urls.get(arch)
            if model_url is None:
                raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
            # The checkpoint is loaded with the default strict load_state_dict
            # below, so the model must have the auxiliary head - presumably
            # because the published checkpoints contain its weights.
            aux_loss = True
            # Every backbone weight is overwritten by the checkpoint below, so
            # downloading separate backbone weights first would be wasted work.
            kwargs["pretrained_backbone"] = False

        model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

        if pretrained:
            state_dict = load_state_dict_from_url(model_url, progress=progress)
            model.load_state_dict(state_dict)
        return model
    

    _segm_resnet

    def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
        """Assemble a segmentation model from a dilated ResNet backbone and a head.

        Args:
            name: Key into the head/model table below: ``'fcn'`` or ``'deeplabv3'``.
            backbone_name: Name of a constructor in the ``resnet`` module,
                e.g. ``'resnet50'``.
            num_classes: Number of output channels of the classifier head(s).
            aux: If truthy, also expose ``layer3`` and attach an auxiliary
                ``FCNHead`` to it.
            pretrained_backbone: Passed to the backbone constructor as ``pretrained``.

        Returns:
            An instance of the model class selected by ``name``, constructed as
            ``model_cls(backbone, classifier, aux_classifier)``.
        """
        # Look the constructor up by name in the resnet module namespace; for
        # backbone_name == 'resnet50' this is the same as calling
        # resnet.resnet50(...).
        backbone_ctor = vars(resnet)[backbone_name]
        backbone = backbone_ctor(
            pretrained=pretrained_backbone,
            replace_stride_with_dilation=[False, True, True])

        # Map backbone stage names to the keys of the feature dict the wrapped
        # backbone returns: layer4 -> 'out', and optionally layer3 -> 'aux'.
        # NOTE(review): with stride replaced by dilation in the last two stages,
        # layer3 and layer4 presumably keep the spatial size of layer2 rather
        # than halving it - confirm against the torchvision ResNet docs.
        return_layers = {'layer4': 'out'}
        if aux:
            return_layers.update(layer3='aux')
        # Trim the ResNet down to a feature extractor that returns only the
        # requested intermediate outputs.
        backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

        # Channel counts are hard-coded: 1024 for layer3 and 2048 for layer4.
        # NOTE(review): these presumably only fit Bottleneck-based ResNets such
        # as resnet50 - verify before using another backbone_name.
        aux_classifier = FCNHead(1024, num_classes) if aux else None

        # (head class, model class) per supported architecture name.
        model_map = {
            'deeplabv3': (DeepLabHead, DeepLabV3),
            'fcn': (FCNHead, FCN),
        }
        head_cls, model_cls = model_map[name]
        classifier = head_cls(2048, num_classes)

        return model_cls(backbone, classifier, aux_classifier)
    

    FCNHead

    class FCNHead(nn.Sequential):
        """Classifier head: 3x3 conv -> BatchNorm -> ReLU -> Dropout -> 1x1 conv.

        Args:
            in_channels: Number of channels of the incoming feature map.
            channels: Number of output channels (the callers in this file pass
                the number of classes).
        """

        def __init__(self, in_channels, channels):
            # The hidden width is a quarter of the input width.
            mid_channels = in_channels // 4
            super().__init__(
                nn.Conv2d(in_channels, mid_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Conv2d(mid_channels, channels, 1),
            )
    

    FCN即_SimpleSegmentationModel

    class _SimpleSegmentationModel(nn.Module):
        """Backbone + classifier (+ optional auxiliary classifier) segmentation model.

        ``forward`` returns an ``OrderedDict`` with key ``"out"`` and, when an
        auxiliary classifier is present, key ``"aux"``. Both tensors are
        bilinearly resized to the spatial size of the input.
        """

        __constants__ = ['aux_classifier']

        def __init__(self, backbone, classifier, aux_classifier=None):
            super().__init__()
            # Feature extractor; forward() expects it to return a dict of tensors
            # with key "out" (and "aux" when aux_classifier is set).
            self.backbone = backbone
            # Head applied to features["out"].
            self.classifier = classifier
            # Optional head applied to features["aux"]; None disables that branch.
            self.aux_classifier = aux_classifier

        def forward(self, x):
            # Last two dimensions of the input (height, width): the resize target.
            target_size = x.shape[-2:]
            # contract: features is a dict of tensors
            features = self.backbone(x)

            result = OrderedDict()

            main_logits = self.classifier(features["out"])
            result["out"] = F.interpolate(
                main_logits, size=target_size, mode='bilinear', align_corners=False)

            if self.aux_classifier is not None:
                aux_logits = self.aux_classifier(features["aux"])
                result["aux"] = F.interpolate(
                    aux_logits, size=target_size, mode='bilinear', align_corners=False)

            return result
    

    相关文章

      网友评论

          本文标题:torchvision.models.segmentation.

          本文链接:https://www.haomeiwen.com/subject/zrhoqltx.html