Transformer in Vision (Part 4): DETR (DEtection TRansformer)

Author: blackmanba_084b | Published 2022-01-13 17:59

    This last installment of the Transformer series covers Transformer-based object detection.
    As usual, the paper and code are linked first:
    paper: End-to-End Object Detection with Transformers
    code: facebookresearch/detr

    1. Principles

    1.1 Model architecture

    The principle behind DETR is quite simple; the figure below shows the overall pipeline.

    Figure: the DETR pipeline
    From the figure above, the input image first passes through a CNN to produce a set of patch features, which are then fed into a Transformer for encoding and decoding. The encoding stage is the same as in ViT (see Transformer 在图像中的运用(一)VIT(Transformers for Image Recognition at Scale)论文及代码解读); the main difference lies in the decoding stage, where detection directly predicts 100 boxes (covering both foreground and background).

    1.2 The decoder

    Figure: detailed architecture
    From this figure we can see that the encoder supplies the matrices K and V, while the transformer decoder supplies Q. The Q, highlighted in the red box, can be understood as a set of feature probes: intuitively, each query has its own task of extracting particular features.

    Below we can see what the encoder alone already accomplishes.



    The concrete layer-by-layer structure below gives a clearer picture of the network.


    • Note that the decoder initializes the object queries as (0 + positional encoding).
    • Note that the first attention in the decoder is self-attention and only the second is cross-attention. Intuitively, the first attention lets the queries internally settle their respective tasks, i.e. which features each should extract; in the second, each query Q is matched against the encoder's K and V. A minimal sketch of one such layer follows below.
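    Below is a minimal sketch of one such decoder layer (toy tensors, norms and dropout omitted; this is not the actual DETR code, which appears in the walkthrough below), just to show the two attention stages:

    import torch
    import torch.nn as nn

    d_model, n_heads, num_queries, hw, bs = 256, 8, 100, 696, 2

    self_attn = nn.MultiheadAttention(d_model, n_heads)   # stage 1: queries attend to each other
    cross_attn = nn.MultiheadAttention(d_model, n_heads)  # stage 2: queries attend to encoder memory

    query_pos = nn.Embedding(num_queries, d_model).weight  # learned object queries (the positional part)
    tgt = torch.zeros(num_queries, bs, d_model)             # content part is initialized to 0
    memory = torch.randn(hw, bs, d_model)                   # stand-in for the encoder output (K, V)
    pos = torch.randn(hw, bs, d_model)                      # stand-in for the 2D positional encoding

    q = k = tgt + query_pos.unsqueeze(1)                    # "0 + positional encoding"
    tgt = tgt + self_attn(q, k, value=tgt)[0]               # queries sort out their respective roles
    tgt = tgt + cross_attn(query=tgt + query_pos.unsqueeze(1),
                           key=memory + pos, value=memory)[0]  # Q against the encoder's K and V
    print(tgt.shape)  # torch.Size([100, 2, 256])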

    1.3 Loss function

    The last point is how the loss is computed. In the figure below, for example, the ground truth contains only two objects, yet we always predict exactly 100 boxes, so how should they be matched? Hungarian matching is used here: the assignment with the smallest total cost is chosen, and every unmatched prediction is treated as background.
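    A toy illustration of this matching step (assuming scipy is available; here the cost is only an L1 box distance, whereas DETR combines class, L1 and GIoU costs):

    import torch
    from scipy.optimize import linear_sum_assignment

    num_queries = 100
    pred_boxes = torch.rand(num_queries, 4)          # 100 predicted (cx, cy, w, h) boxes
    gt_boxes = torch.tensor([[0.3, 0.4, 0.2, 0.2],   # only 2 ground-truth boxes
                             [0.7, 0.6, 0.3, 0.4]])

    cost = torch.cdist(pred_boxes, gt_boxes, p=1)    # (100, 2) cost matrix
    pred_idx, gt_idx = linear_sum_assignment(cost)   # one-to-one assignment with minimal total cost
    print(pred_idx, gt_idx)  # e.g. [17 63] [0 1]: those two queries are matched, the other 98 are background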


    One more point: the decoder has many layers, and the loss can be computed at every layer (auxiliary losses), which improves results; the code walkthrough below covers this in detail.

    1.4 Qualitative results


    Two mutually occluding objects were deliberately chosen; focusing on the attention colors in the figure, the attention maps are quite accurate.

    2. Code walkthrough

    2.1 DETR

    class DETR(nn.Module):
        """ This is the DETR module that performs object detection """
        def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
            """ Initializes the model.
            Parameters:
                backbone: torch module of the backbone to be used. See backbone.py
                transformer: torch module of the transformer architecture. See transformer.py
                num_classes: number of object classes
                num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                             DETR can detect in a single image. For COCO, we recommend 100 queries.
                aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            """
            super().__init__()
            self.num_queries = num_queries
            self.transformer = transformer
            hidden_dim = transformer.d_model
            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
            self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
            self.query_embed = nn.Embedding(num_queries, hidden_dim)
            self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
            self.backbone = backbone
            self.aux_loss = aux_loss
    
        def forward(self, samples: NestedTensor):
            """ The forward expects a NestedTensor, which consists of:
                   - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
                   - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
    
                It returns a dict with the following elements:
                   - "pred_logits": the classification logits (including no-object) for all queries.
                                    Shape= [batch_size x num_queries x (num_classes + 1)]
                   - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                                   (center_x, center_y, height, width). These values are normalized in [0, 1],
                                   relative to the size of each individual image (disregarding possible padding).
                                   See PostProcess for information on how to retrieve the unnormalized bounding box.
                   - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                    dictionnaries containing the two above keys for each decoder layer.
            """
            if isinstance(samples, (list, torch.Tensor)):
                samples = nested_tensor_from_tensor_list(samples)
            features, pos = self.backbone(samples)
    
            src, mask = features[-1].decompose()
            assert mask is not None
            hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
    
            outputs_class = self.class_embed(hs)
            outputs_coord = self.bbox_embed(hs).sigmoid()
            out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
            if self.aux_loss:
                out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
            return out
    

    First, features, pos = self.backbone(samples) returns the ResNet-50 features together with the positional information. How the positional information is obtained is explained by the code below.

    def build_backbone(args):
        position_embedding = build_position_encoding(args)
        train_backbone = args.lr_backbone > 0
        return_interm_layers = args.masks
        backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
        model = Joiner(backbone, position_embedding)
        model.num_channels = backbone.num_channels
        return model
    

    2.2 Positional encoding

    def build_position_encoding(args):
        N_steps = args.hidden_dim // 2
        if args.position_embedding in ('v2', 'sine'):
            # TODO find a better way of exposing other arguments
            position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
        elif args.position_embedding in ('v3', 'learned'):
            position_embedding = PositionEmbeddingLearned(N_steps)
        else:
            raise ValueError(f"not supported {args.position_embedding}")
    
        return position_embedding
    

    The learned variant below is also easy to understand: we simply create embedding vectors and let them be learned.

    self.row_embed = nn.Embedding(50, num_pos_feats)
    self.col_embed = nn.Embedding(50, num_pos_feats)
    

    1. Positional encoding, option 1: learned

    class PositionEmbeddingLearned(nn.Module):
        """
        Absolute pos embedding, learned.
        """
        def __init__(self, num_pos_feats=256):
            super().__init__()
            self.row_embed = nn.Embedding(50, num_pos_feats)
            self.col_embed = nn.Embedding(50, num_pos_feats)
            self.reset_parameters()
    
        def reset_parameters(self):
            nn.init.uniform_(self.row_embed.weight)
            nn.init.uniform_(self.col_embed.weight)
    
        def forward(self, tensor_list: NestedTensor):
            x = tensor_list.tensors
            h, w = x.shape[-2:]
            i = torch.arange(w, device=x.device)
            j = torch.arange(h, device=x.device)
            x_emb = self.col_embed(i)
            y_emb = self.row_embed(j)
            pos = torch.cat([
                x_emb.unsqueeze(0).repeat(h, 1, 1),
                y_emb.unsqueeze(1).repeat(1, w, 1),
            ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
            return pos
    
    

    2. Positional encoding, option 2: sinusoidal (recommended)

    class PositionEmbeddingSine(nn.Module):
        """
        This is a more standard version of the position embedding, very similar to the one
        used by the Attention is all you need paper, generalized to work on images.
        """
        def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
            super().__init__()
            self.num_pos_feats = num_pos_feats
            self.temperature = temperature
            self.normalize = normalize
            if scale is not None and normalize is False:
                raise ValueError("normalize should be True if scale is passed")
            if scale is None:
                scale = 2 * math.pi
            self.scale = scale
    
        def forward(self, tensor_list: NestedTensor):
            x = tensor_list.tensors
            mask = tensor_list.mask
            assert mask is not None
            not_mask = ~mask
            y_embed = not_mask.cumsum(1, dtype=torch.float32) # cumulative sum along the row (height) direction
            x_embed = not_mask.cumsum(2, dtype=torch.float32) # cumulative sum along the column (width) direction
            if self.normalize:
                eps = 1e-6
                y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
                x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    
            dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
            dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
    
            pos_x = x_embed[:, :, :, None] / dim_t
            pos_y = y_embed[:, :, :, None] / dim_t
            pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
            pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
            pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
            return pos
    

    Here sine and cosine functions of different frequencies are used for the encoding (for details see BERT(一) Transformer原理理解); the formula is as follows:
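    (the standard sinusoidal encoding from Attention Is All You Need)

    PE(pos, 2i)   = sin(pos / 10000^(2i / d))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d))

    where pos is the position (here the cumulative row or column index), i indexes the feature dimension, and d is num_pos_feats (128 each for the x and y directions in this implementation).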


    Let's look at the forward function. The input feature map x has shape (2, 2048, 24, 29), corresponding to batch (from here on, a leading 2 is always the batch_size and will not be called out again), channel, h, w.
    The mask consists of True/False values and has shape (2, 24, 29), corresponding to batch, h, w. Where the mask is True the position is padding; where it is False it is valid. Why padding at all? To stack images of different sizes into one batch, some of them have to be padded during preprocessing; the mask then lets the model learn targets only within the valid region, which effectively injects a bit of prior knowledge. Next, cumulative sums are taken along the row and column directions; some values are repeated but can be ignored because those positions are padding (True). The code below shows how the padding is constructed, and how the mask is afterwards resized down to the feature-map resolution:
    # file_path: detr/utils/misc.py
    
    def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
        # TODO make this more general
        if tensor_list[0].ndim == 3:
            if torchvision._is_tracing():
                # nested_tensor_from_tensor_list() does not export well to ONNX
                # call _onnx_nested_tensor_from_tensor_list() instead
                return _onnx_nested_tensor_from_tensor_list(tensor_list)
    
            # TODO make it support different-sized images
            max_size = _max_by_axis([list(img.shape) for img in tensor_list])
            # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
            batch_shape = [len(tensor_list)] + max_size
            b, c, h, w = batch_shape
            dtype = tensor_list[0].dtype
            device = tensor_list[0].device
            tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
            mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
            for img, pad_img, m in zip(tensor_list, tensor, mask):
                pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
                m[: img.shape[1], :img.shape[2]] = False
        else:
            raise ValueError('not supported')
        return NestedTensor(tensor, mask)
    
    # file_path: detr/models/backbone.py
    # interpolate the mask down to the feature-map resolution
    mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
    
    y_embed = not_mask.cumsum(1, dtype=torch.float32) # cumulative sum along the row (height) direction
    x_embed = not_mask.cumsum(2, dtype=torch.float32) # cumulative sum along the column (width) direction
    

    For example, y_embed looks as follows:
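    A toy example (a 1×3×4 mask whose last column is padding, rather than the real 24×29 map) illustrates the two cumulative sums:

    import torch

    mask = torch.tensor([[[False, False, False, True],
                          [False, False, False, True],
                          [False, False, False, True]]])   # True = padding
    not_mask = ~mask
    y_embed = not_mask.cumsum(1, dtype=torch.float32)  # row index, growing downwards
    x_embed = not_mask.cumsum(2, dtype=torch.float32)  # column index, growing rightwards
    print(y_embed[0])
    # tensor([[1., 1., 1., 0.],
    #         [2., 2., 2., 0.],
    #         [3., 3., 3., 0.]])
    print(x_embed[0])
    # tensor([[1., 2., 3., 3.],
    #         [1., 2., 3., 3.],
    #         [1., 2., 3., 3.]])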


    Next comes the normalization step, as shown below:
            if self.normalize:
                eps = 1e-6
                y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
                x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    

    In the code below, self.num_pos_feats defaults to 128; torch.arange is used because the odd and even dimensions are handled differently (even indices go through sin, odd ones through cos).


    We then obtain pos_x and pos_y (the column- and row-wise encodings), each of shape (2, 24, 29, 128). Finally, pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) concatenates them into the final positional encoding.
    The resulting pos has shape (2, 256, 24, 29).
    dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
    dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
    pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
    return pos
    

    2.3 Mask and encoder input

    First, the call

    src, mask = features[-1].decompose()
    

    extracts the last level of features together with the padding mask corresponding to that feature map. The following code then enters the transformer:

    hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
    
    self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
    

    Since src has shape (2, 2048, 24, 29) and its channel dimension of 2048 is too large, self.input_proj (a 1x1 convolution) projects it down to a (2, 256, 24, 29) feature map before it enters the transformer. Here self.query_embed is

    self.query_embed = nn.Embedding(num_queries, hidden_dim)
    

    Likewise, pos[-1] takes the positional encoding of the last feature level (-1), and the trailing [0] keeps only the decoder output hs (the second return value is the encoder memory).
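    A quick sanity check of the projection and the query embedding with dummy tensors (assuming the shapes mentioned above):

    import torch
    import torch.nn as nn

    src = torch.randn(2, 2048, 24, 29)                # last-level backbone feature
    input_proj = nn.Conv2d(2048, 256, kernel_size=1)  # 1x1 conv: 2048 -> 256 channels
    query_embed = nn.Embedding(100, 256)              # 100 learned object queries

    print(input_proj(src).shape)      # torch.Size([2, 256, 24, 29])
    print(query_embed.weight.shape)   # torch.Size([100, 256])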

    2.4 transformer

        def forward(self, src, mask, query_embed, pos_embed):
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            mask = mask.flatten(1)
    
            tgt = torch.zeros_like(query_embed)
            memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
            hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                              pos=pos_embed, query_pos=query_embed)
            return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
    

    As noted above, src has shape (2, 256, 24, 29). After
    src = src.flatten(2).permute(2, 0, 1), src has shape (696, 2, 256), where 696 is the sequence length and 256 the feature dimension. pos_embed is reshaped in the same way to (696, 2, 256). After the repeat, query_embed has shape (100, 2, 256): the 100 queries determine how each of the 100 outputs picks up the features it needs, forming 100 query vectors. The mask has shape (2, 696), and tgt has shape (100, 2, 256).
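    A quick shape check of these reshapes with dummy tensors (same assumed sizes as above):

    import torch

    bs, c, h, w, num_queries = 2, 256, 24, 29, 100
    src = torch.randn(bs, c, h, w)
    pos_embed = torch.randn(bs, c, h, w)
    query_embed = torch.randn(num_queries, c)
    mask = torch.zeros(bs, h, w, dtype=torch.bool)

    src = src.flatten(2).permute(2, 0, 1)                     # (696, 2, 256): sequence of h*w tokens
    pos_embed = pos_embed.flatten(2).permute(2, 0, 1)         # (696, 2, 256)
    query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)   # (100, 2, 256)
    mask = mask.flatten(1)                                    # (2, 696)
    tgt = torch.zeros_like(query_embed)                       # (100, 2, 256)
    print(src.shape, query_embed.shape, mask.shape)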

    1. Encoder
    Let's now focus on the encoder.

    # the mask marks the padded positions that attention should ignore
    memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 
    
    class TransformerEncoderLayer(nn.Module):
    
        def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                     activation="relu", normalize_before=False):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            # Implementation of Feedforward model
            self.linear1 = nn.Linear(d_model, dim_feedforward)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model)
    
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout1 = nn.Dropout(dropout)
            self.dropout2 = nn.Dropout(dropout)
    
            self.activation = _get_activation_fn(activation)
            self.normalize_before = normalize_before
    
        def with_pos_embed(self, tensor, pos: Optional[Tensor]):
            return tensor if pos is None else tensor + pos
    
        def forward_post(self,
                         src,
                         src_mask: Optional[Tensor] = None,
                         src_key_padding_mask: Optional[Tensor] = None,
                         pos: Optional[Tensor] = None):
            q = k = self.with_pos_embed(src, pos)
            src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
            src = src + self.dropout1(src2)
            src = self.norm1(src)
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
            src = src + self.dropout2(src2)
            src = self.norm2(src)
            return src
    
    # The original Transformer applies positional encoding only once, before the encoder, directly on the input, which is then turned into
    # Query, Key and Value by the transformation matrices. DETR instead applies positional encoding before every multi-head self-attention in
    # the encoder, and only to Query and Key: the (HW, B, 256) positional encoding is added to the (HW, B, 256) Query and Key, not to Value.
    q = k = self.with_pos_embed(src, pos)
    
    In the figure, blue marks the Q and K that carry the positional encoding, while V is simply the features.

    Thus q and k both have the same shape (696, 2, 256) and are passed through self.self_attn:

    # q and k are the features plus the positional encoding
    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
         return tensor if pos is None else tensor + pos
    
    src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
               key_padding_mask=src_key_padding_mask)[0]  # returns (attention output, attention weights); we keep only the first
    

    src2 has shape (696, 2, 256). attn_mask is None here (it matters in NLP but is not needed in this setting), while src_key_padding_mask marks the padded positions that must not be attended to. The call returns two values, the self-attention output and the attention weights; we only need the first, while the second is mainly useful for visualization. The remaining operations are the same as in a standard Transformer:

            src = src + self.dropout1(src2)
            src = self.norm1(src)
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
            src = src + self.dropout2(src2)
            src = self.norm2(src)
    

    The encoder then simply stacks multiple such layers:

    class TransformerEncoder(nn.Module):
    
        def __init__(self, encoder_layer, num_layers, norm=None):
            super().__init__()
            self.layers = _get_clones(encoder_layer, num_layers)
            self.num_layers = num_layers
            self.norm = norm
    
        def forward(self, src,
                    mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
            output = src
    
            for layer in self.layers:
                output = layer(output, src_mask=mask,
                               src_key_padding_mask=src_key_padding_mask, pos=pos)
    
            if self.norm is not None:
                output = self.norm(output)
    
            return output
    

    2. Decoder
    Back to the following code:

    tgt = torch.zeros_like(query_embed)
    memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
    hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                              pos=pos_embed, query_pos=query_embed)
    

    The encoder gives us memory, which now goes into the decoder. The memory here has shape (696, 2, 256).

    
    class TransformerDecoderLayer(nn.Module):
    
        def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                     activation="relu", normalize_before=False):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
            # Implementation of Feedforward model
            self.linear1 = nn.Linear(d_model, dim_feedforward)
            self.dropout = nn.Dropout(dropout)
            self.linear2 = nn.Linear(dim_feedforward, d_model)
    
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)
            self.dropout1 = nn.Dropout(dropout)
            self.dropout2 = nn.Dropout(dropout)
            self.dropout3 = nn.Dropout(dropout)
    
            self.activation = _get_activation_fn(activation)
            self.normalize_before = normalize_before
    
        def with_pos_embed(self, tensor, pos: Optional[Tensor]):
            return tensor if pos is None else tensor + pos
    
        def forward_post(self, tgt, memory,
                         tgt_mask: Optional[Tensor] = None,
                         memory_mask: Optional[Tensor] = None,
                         tgt_key_padding_mask: Optional[Tensor] = None,
                         memory_key_padding_mask: Optional[Tensor] = None,
                         pos: Optional[Tensor] = None,
                         query_pos: Optional[Tensor] = None):
            q = k = self.with_pos_embed(tgt, query_pos)
            tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)
            tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
            tgt = tgt + self.dropout3(tgt2)
            tgt = self.norm3(tgt)
            return tgt
    
        def forward_pre(self, tgt, memory,
                        tgt_mask: Optional[Tensor] = None,
                        memory_mask: Optional[Tensor] = None,
                        tgt_key_padding_mask: Optional[Tensor] = None,
                        memory_key_padding_mask: Optional[Tensor] = None,
                        pos: Optional[Tensor] = None,
                        query_pos: Optional[Tensor] = None):
            tgt2 = self.norm1(tgt)
            q = k = self.with_pos_embed(tgt2, query_pos)
            tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
            tgt = tgt + self.dropout1(tgt2)
            tgt2 = self.norm2(tgt)
            tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt2 = self.norm3(tgt)
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
            tgt = tgt + self.dropout3(tgt2)
            return tgt
    
        def forward(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
            if self.normalize_before:
                return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                        tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
            return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                     tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
    

    Let's focus on the forward_post function; read alongside the pipeline diagram it is easy to follow.

        def forward_post(self, tgt, memory,
                         tgt_mask: Optional[Tensor] = None,
                         memory_mask: Optional[Tensor] = None,
                         tgt_key_padding_mask: Optional[Tensor] = None,
                         memory_key_padding_mask: Optional[Tensor] = None,
                         pos: Optional[Tensor] = None,
                         query_pos: Optional[Tensor] = None):
            q = k = self.with_pos_embed(tgt, query_pos)
            tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                                  key_padding_mask=tgt_key_padding_mask)[0]
            tgt = tgt + self.dropout1(tgt2)
            tgt = self.norm1(tgt)
            tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)[0]
            tgt = tgt + self.dropout2(tgt2)
            tgt = self.norm2(tgt)
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
            tgt = tgt + self.dropout3(tgt2)
            tgt = self.norm3(tgt)
            return tgt
    

    As mentioned earlier, tgt is initialized to all zeros because the queries carry no prior content at the start; its shape is (100, 2, 256). q and k are then tgt plus the query positional embedding:

    q = k = self.with_pos_embed(tgt, query_pos)
    
    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos
    

    Next comes attention again; note that the cross-attention (multihead_attn) is distinct from the self-attention (self_attn):

    self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    
    tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
               key_padding_mask=tgt_key_padding_mask)[0]
    

    The code below maps one-to-one onto the structure diagram above, so it is straightforward to follow.

    3. Encoder and decoder together

        def forward(self, src, mask, query_embed, pos_embed):
            # flatten NxCxHxW to HWxNxC
            bs, c, h, w = src.shape
            src = src.flatten(2).permute(2, 0, 1)
            pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
            mask = mask.flatten(1)
    
            tgt = torch.zeros_like(query_embed)
            memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
            hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                              pos=pos_embed, query_pos=query_embed)
            return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
    

    The decoder finally returns hs of shape (6, 100, 2, 256): the 6 means the outputs of all 6 decoder layers are stacked, which makes it possible to compute a loss for every layer. 2 is the batch_size, 100 is the number of predicted queries, and 256 is the dimension of each query.

    4. Output heads

    hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
    outputs_class = self.class_embed(hs)
    outputs_coord = self.bbox_embed(hs).sigmoid()
    

    After the transpose, hs has shape (6, 2, 100, 256).

    self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
    
    self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
    

    outputs_class has shape (6, 2, 100, 92) and outputs_coord has shape (6, 2, 100, 4); num_classes is 91, plus one background class gives 91 + 1 = 92.

    This yields N = 100 predicted targets, each with a class and a bounding box; 100 is of course larger than the number of objects in any image. Slots that do not correspond to an object are assigned to the background; when computing the loss, the regression branch is evaluated only at matched object positions, and background slots are ignored. DETR's output tensors therefore have shapes (b, 100, class+1) and (b, 100, 4). For COCO, class+1 = 92, and the 4 values are the normalized (cx, cy, w, h) of each predicted box, where normalization means dividing by the image width and height.
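    A hedged sketch of turning these outputs into final boxes for a single image, roughly in the spirit of the repo's PostProcess (the 0.7 confidence threshold and the 480×640 image size are made up for illustration):

    import torch

    def to_xyxy_pixels(pred_logits, pred_boxes, img_h, img_w, keep_thresh=0.7):
        # pred_logits: (100, 92), pred_boxes: (100, 4) normalized (cx, cy, w, h) for one image
        probs = pred_logits.softmax(-1)[:, :-1]          # drop the background column
        scores, labels = probs.max(-1)
        keep = scores > keep_thresh                      # keep confident, non-background queries
        cx, cy, w, h = pred_boxes[keep].unbind(-1)
        boxes = torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                             cx + 0.5 * w, cy + 0.5 * h], dim=-1)
        boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=boxes.dtype)  # back to pixels
        return scores[keep], labels[keep], boxes

    scores, labels, boxes = to_xyxy_pixels(torch.randn(100, 92), torch.rand(100, 4), 480, 640)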

    2.5 Loss
    The loss consists of three parts: the classification loss, the L1 box-regression loss, and the GIoU loss.

    class SetCriterion(nn.Module):
        """ This class computes the loss for DETR.
        The process happens in two steps:
            1) we compute hungarian assignment between ground truth boxes and the outputs of the model
            2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
        """
        def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
            """ Create the criterion.
            Parameters:
                num_classes: number of object categories, omitting the special no-object category
                matcher: module able to compute a matching between targets and proposals
                weight_dict: dict containing as key the names of the losses and as values their relative weight.
                eos_coef: relative classification weight applied to the no-object category
                losses: list of all the losses to be applied. See get_loss for list of available losses.
            """
            super().__init__()
            self.num_classes = num_classes
            self.matcher = matcher
            self.weight_dict = weight_dict
            self.eos_coef = eos_coef
            self.losses = losses
            empty_weight = torch.ones(self.num_classes + 1)
            empty_weight[-1] = self.eos_coef
            self.register_buffer('empty_weight', empty_weight)
    
        def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
            """Classification loss (NLL)
            targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
            """
            assert 'pred_logits' in outputs
            src_logits = outputs['pred_logits']
    
            idx = self._get_src_permutation_idx(indices)
            target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
            target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                        dtype=torch.int64, device=src_logits.device)
            target_classes[idx] = target_classes_o
    
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
            losses = {'loss_ce': loss_ce}
    
            if log:
                # TODO this should probably be a separate loss, not hacked in this one here
                losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
            return losses
    
        @torch.no_grad()
        def loss_cardinality(self, outputs, targets, indices, num_boxes):
            """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
            This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
            """
            pred_logits = outputs['pred_logits']
            device = pred_logits.device
            tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
            # Count the number of predictions that are NOT "no-object" (which is the last class)
            card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
            card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
            losses = {'cardinality_error': card_err}
            return losses
    
        def loss_boxes(self, outputs, targets, indices, num_boxes):
            """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
               targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
               The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
            """
            assert 'pred_boxes' in outputs
            idx = self._get_src_permutation_idx(indices)
            src_boxes = outputs['pred_boxes'][idx]
            target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
    
            loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
    
            losses = {}
            losses['loss_bbox'] = loss_bbox.sum() / num_boxes
    
            loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes)))
            losses['loss_giou'] = loss_giou.sum() / num_boxes
            return losses
    
        def loss_masks(self, outputs, targets, indices, num_boxes):
            """Compute the losses related to the masks: the focal loss and the dice loss.
               targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
            """
            assert "pred_masks" in outputs
    
            src_idx = self._get_src_permutation_idx(indices)
            tgt_idx = self._get_tgt_permutation_idx(indices)
            src_masks = outputs["pred_masks"]
            src_masks = src_masks[src_idx]
            masks = [t["masks"] for t in targets]
            # TODO use valid to mask invalid areas due to padding in loss
            target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
            target_masks = target_masks.to(src_masks)
            target_masks = target_masks[tgt_idx]
    
            # upsample predictions to the target size
            src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                    mode="bilinear", align_corners=False)
            src_masks = src_masks[:, 0].flatten(1)
    
            target_masks = target_masks.flatten(1)
            target_masks = target_masks.view(src_masks.shape)
            losses = {
                "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
                "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
            }
            return losses
    
        def _get_src_permutation_idx(self, indices):
            # permute predictions following indices
            batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
            src_idx = torch.cat([src for (src, _) in indices])
            return batch_idx, src_idx
    
        def _get_tgt_permutation_idx(self, indices):
            # permute targets following indices
            batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
            tgt_idx = torch.cat([tgt for (_, tgt) in indices])
            return batch_idx, tgt_idx
    
        def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
            loss_map = {
                'labels': self.loss_labels,
                'cardinality': self.loss_cardinality,
                'boxes': self.loss_boxes,
                'masks': self.loss_masks
            }
            assert loss in loss_map, f'do you really want to compute {loss} loss?'
            return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
    
        def forward(self, outputs, targets):
            """ This performs the loss computation.
            Parameters:
                 outputs: dict of tensors, see the output specification of the model for the format
                 targets: list of dicts, such that len(targets) == batch_size.
                          The expected keys in each dict depends on the losses applied, see each loss' doc
            """
            outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
    
            # Retrieve the matching between the outputs of the last layer and the targets
            indices = self.matcher(outputs_without_aux, targets)
    
            # Compute the average number of target boxes accross all nodes, for normalization purposes
            num_boxes = sum(len(t["labels"]) for t in targets)
            num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
            if is_dist_avail_and_initialized():
                torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
    
            # Compute all the requested losses
            losses = {}
            for loss in self.losses:
                losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
    
            # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
            if 'aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['aux_outputs']):
                    indices = self.matcher(aux_outputs, targets)
                    for loss in self.losses:
                        if loss == 'masks':
                            # Intermediate masks losses are too costly to compute, we ignore them.
                            continue
                        kwargs = {}
                        if loss == 'labels':
                            # Logging is enabled only for the last layer
                            kwargs = {'log': False}
                        l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                        l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                        losses.update(l_dict)
    
            return losses
    

    Let's first look at the forward function.

        def forward(self, outputs, targets):
            """ This performs the loss computation.
            Parameters:
                 outputs: dict of tensors, see the output specification of the model for the format
                 targets: list of dicts, such that len(targets) == batch_size.
                          The expected keys in each dict depends on the losses applied, see each loss' doc
            """
            outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
    
            # Retrieve the matching between the outputs of the last layer and the targets
            indices = self.matcher(outputs_without_aux, targets)
    
            # Compute the average number of target boxes accross all nodes, for normalization purposes
            num_boxes = sum(len(t["labels"]) for t in targets)
            num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
            if is_dist_avail_and_initialized():
                torch.distributed.all_reduce(num_boxes)
            num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
    
            # Compute all the requested losses
            losses = {}
            for loss in self.losses:
                losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
    
            # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
            if 'aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['aux_outputs']):
                    indices = self.matcher(aux_outputs, targets)
                    for loss in self.losses:
                        if loss == 'masks':
                            # Intermediate masks losses are too costly to compute, we ignore them.
                            continue
                        kwargs = {}
                        if loss == 'labels':
                            # Logging is enabled only for the last layer
                            kwargs = {'log': False}
                        l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                        l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                        losses.update(l_dict)
    
            return losses
    

    outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} strips out the auxiliary outputs, i.e. the per-layer decoder predictions (as mentioned, the decoder has 6 layers, and each can contribute its own loss via aux_outputs).
    Another question: we have 100 predicted boxes but the ground truth may contain only 2, so how do we match them to compute the loss? The Hungarian matching algorithm is used, via the following call:

    # Retrieve the matching between the outputs of the last layer and the targets
    indices = self.matcher(outputs_without_aux, targets)
    

    The matcher code is as follows:

    class HungarianMatcher(nn.Module):
        """This class computes an assignment between the targets and the predictions of the network
    
        For efficiency reasons, the targets don't include the no_object. Because of this, in general,
        there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
        while the others are un-matched (and thus treated as non-objects).
        """
    
        def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
            """Creates the matcher
    
            Params:
                cost_class: This is the relative weight of the classification error in the matching cost
                cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
                cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            """
            super().__init__()
            self.cost_class = cost_class
            self.cost_bbox = cost_bbox
            self.cost_giou = cost_giou
            assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
    
        @torch.no_grad()
        def forward(self, outputs, targets):
            """ Performs the matching
    
            Params:
                outputs: This is a dict that contains at least these entries:
                     "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                     "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
    
                targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                     "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                               objects in the target) containing the class labels
                     "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
    
            Returns:
                A list of size batch_size, containing tuples of (index_i, index_j) where:
                    - index_i is the indices of the selected predictions (in order)
                    - index_j is the indices of the corresponding selected targets (in order)
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
            """
            bs, num_queries = outputs["pred_logits"].shape[:2]
    
            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
    
            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])
    
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be ommitted.
            cost_class = -out_prob[:, tgt_ids]
    
            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
    
            # Compute the giou cost betwen boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
    
            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()
    
            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
    
    
    def build_matcher(args):
        return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
    

    Below are the three matching-cost terms:

    # Compute the classification cost. Contrary to the loss, we don't use the NLL,
    # but approximate it in 1 - proba[target class].
    # The 1 is a constant that doesn't change the matching, it can be ommitted.
    cost_class = -out_prob[:, tgt_ids]
    
    # Compute the L1 cost between boxes
    cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
    
    # Compute the giou cost betwen boxes
    cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
    

    The Hungarian algorithm (linear_sum_assignment) then selects the assignment with the smallest total cost.

    return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
    

    What is returned are the indices of the optimal matching; the losses we need are then computed from these matched pairs.
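    To make the returned indices concrete, a small illustrative example (toy batch; the index values are made up) of how they can be used to gather the matched predictions, in the spirit of _get_src_permutation_idx above:

    import torch

    # suppose the matcher returned, for a batch of 2 images:
    indices = [(torch.tensor([17, 63]), torch.tensor([0, 1])),   # image 0: queries 17, 63 <-> GT 0, 1
               (torch.tensor([5]), torch.tensor([0]))]           # image 1: query 5 <-> GT 0

    pred_boxes = torch.rand(2, 100, 4)   # (batch, num_queries, 4)

    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
    src_idx = torch.cat([src for (src, _) in indices])
    matched = pred_boxes[batch_idx, src_idx]   # (3, 4): only the matched predictions enter the box losses
    print(matched.shape)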
