SegCaps Routing Explained

Author: zelda2333 | Published 2020-12-03 15:28

Paper: Capsules for Object Segmentation
Code: https://github.com/lalonderodney/SegCaps

[Figure: routing diagram (路由.png)]

Understanding capsule_layers.py

We use the first capsule layer, primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same', routings=1, name='primarycaps')(conv1_reshaped), as the running example.

Input shape: (batch_size, H, W, C)

conv1_reshaped: shape (N, 512, 512, 1, 16), i.e. (batch_size, H, W, input_num_capsule, input_num_atoms)

primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same', routings=1, name='primarycaps')(conv1_reshaped)

class ConvCapsuleLayer

__init__() and build() handle initialization. build() reads the following from input_shape and creates the weights:

input_height = input_shape[1]        # 512
input_width = input_shape[2]         # 512
input_num_capsule = input_shape[3]   # 1
input_num_atoms = input_shape[4]     # 16 (capsule dimension)

W: shape (kernel_size, kernel_size, input_num_atoms, num_capsule * num_atoms) = (5, 5, 16, 2*16)
b: shape (1, 1, num_capsule, num_atoms) = (1, 1, 2, 16)
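As a small sanity check of the shapes above, here is a sketch that derives the W and b shapes from the 5D input shape. The helper name conv_capsule_param_shapes is hypothetical, not from the repository. It reproduces only the shape arithmetic, not the weight initialization.

```python
def conv_capsule_param_shapes(input_shape, kernel_size, num_capsule, num_atoms):
    """Shape arithmetic of ConvCapsuleLayer.build() (hypothetical helper).

    input_shape: (N, input_height, input_width, input_num_capsule, input_num_atoms)
    """
    _, _, _, _, input_num_atoms = input_shape
    # One kernel maps the atoms of one input capsule to all output atoms; because
    # call() folds the input capsule types into the batch, W is shared across them.
    w_shape = (kernel_size, kernel_size, input_num_atoms, num_capsule * num_atoms)
    # One bias vector per output capsule type, tiled over H x W later in call().
    b_shape = (1, 1, num_capsule, num_atoms)
    return w_shape, b_shape


w_shape, b_shape = conv_capsule_param_shapes((None, 512, 512, 1, 16), 5, 2, 16)
print(w_shape, b_shape)   # (5, 5, 16, 32) (1, 1, 2, 16)
```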

call()

input_transposed = tf.transpose(input_tensor, [3, 0, 1, 2, 4])   # (input_num_capsule, N, input_height, input_width, input_num_atoms) = (1, N, 512, 512, 16)

input_shape = K.shape(input_transposed)   # dynamic shape of the transposed tensor, so now input_shape[0] = input_num_capsule and input_shape[1] = N

input_tensor_reshaped = K.reshape(input_transposed, [input_shape[0] * input_shape[1], input_height, input_width, input_num_atoms])   # fold the input capsule types into the batch: (input_num_capsule * N, input_height, input_width, input_num_atoms) = (1*N, 512, 512, 16)

conv = K.conv2d(input_tensor_reshaped, W, (strides, strides), padding=padding, data_format='channels_last')   # (1*N, 256, 256, num_capsule * num_atoms) = (1*N, 256, 256, 2*16)
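The transpose-and-reshape step folds the input-capsule axis into the batch axis, so one ordinary 2D convolution handles every input capsule type with the same kernel W. Below is a NumPy sketch of that bookkeeping with small, made-up sizes, plus the output-size rule for padding='same' that turns 512 into 256 at stride 2. NumPy stands in for the Keras backend, and the convolution itself is not reproduced; same_padding_output_size is a hypothetical helper.

```python
import math
import numpy as np

def same_padding_output_size(size, stride):
    # For padding='same', TensorFlow's spatial output size is ceil(size / stride).
    return math.ceil(size / stride)

# Hypothetical small sizes: N=2, 8x8 grid, 3 input capsule types, 4 atoms each.
N, h, w, in_caps, in_atoms = 2, 8, 8, 3, 4
x = np.random.rand(N, h, w, in_caps, in_atoms).astype(np.float32)

# Move the input-capsule axis to the front, then fold it into the batch axis,
# so one 2D convolution with a single shared kernel sees every input capsule type.
x_t = np.transpose(x, [3, 0, 1, 2, 4])                 # (in_caps, N, h, w, in_atoms)
x_folded = x_t.reshape(in_caps * N, h, w, in_atoms)    # (in_caps * N, h, w, in_atoms)

# Row c * N + n of the folded batch is input capsule type c of sample n.
assert np.array_equal(x_folded[1 * N + 0], x[0, :, :, 1, :])

print(x_t.shape, x_folded.shape)          # (3, 2, 8, 8, 4) (6, 8, 8, 4)
print(same_padding_output_size(512, 2))   # 256
```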

votes_shape = K.shape(conv)   # dynamic shape: (1*N, 256, 256, num_capsule * num_atoms) = (1*N, 256, 256, 2*16)

_, conv_height, conv_width, _ = conv.get_shape()   # static sizes: conv_height = 256, conv_width = 256

votes = K.reshape(conv, [input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], num_capsule, num_atoms])   # (N, input_num_capsule, 256, 256, num_capsule, num_atoms) = (N, 1, 256, 256, 2, 16)

logit_shape = K.stack([input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], num_capsule])   # 1D tensor of shape (5,) holding (N, input_num_capsule, 256, 256, num_capsule) = (N, 1, 256, 256, 2); individual sizes can be read with logit_shape[0], logit_shape[1], ...

biases_replicated = K.tile(b, [conv_height.value, conv_width.value, 1, 1])   # repeats b along each axis by the given multiples: (256, 256, num_capsule, num_atoms) = (256, 256, 2, 16); .value is the TF 1.x Dimension API
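A NumPy sketch of the next step: splitting the conv channels into (num_capsule, num_atoms) to form votes, and tiling the bias over every output position. The sizes are hypothetical and small, the conv output is random stand-in data, and b holds placeholder values (not the layer's real initialization) so the tiling is easy to check.

```python
import numpy as np

# Hypothetical small sizes: N=2, one input capsule type, 4x4 output grid,
# 2 output capsule types with 16 atoms each (as in primarycaps).
N, in_caps, out_h, out_w, num_capsule, num_atoms = 2, 1, 4, 4, 2, 16

# Random stand-in for the conv2d output: (in_caps * N, out_h, out_w, num_capsule * num_atoms).
conv = np.random.rand(in_caps * N, out_h, out_w, num_capsule * num_atoms).astype(np.float32)

# Split the channel axis into (num_capsule, num_atoms) and restore (N, in_caps),
# mirroring the K.reshape call that builds votes.
votes = conv.reshape(N, in_caps, out_h, out_w, num_capsule, num_atoms)

# logit_shape: one routing logit per (sample, input capsule, position, output capsule).
logit_shape = (N, in_caps, out_h, out_w, num_capsule)

# Placeholder bias values with b's shape (1, 1, num_capsule, num_atoms);
# np.tile repeats b out_h times along axis 0 and out_w times along axis 1.
b = np.arange(num_capsule * num_atoms, dtype=np.float32).reshape(1, 1, num_capsule, num_atoms)
biases_replicated = np.tile(b, [out_h, out_w, 1, 1])

print(votes.shape, logit_shape)   # (2, 1, 4, 4, 2, 16) (2, 1, 4, 4, 2)
print(biases_replicated.shape)    # (4, 4, 2, 16)
```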

activations = update_routing(
    votes=votes,                   # (N, input_num_capsule, 256, 256, num_capsule, num_atoms) = (N, 1, 256, 256, 2, 16)
    biases=biases_replicated,      # (256, 256, num_capsule, num_atoms) = (256, 256, 2, 16)
    logit_shape=logit_shape,       # (N, input_num_capsule, 256, 256, num_capsule) = (N, 1, 256, 256, 2)
    num_dims=6,                    # rank of votes
    input_dim=input_num_capsule,   # 1
    output_dim=num_capsule,        # 2
    num_routing=routings)          # 1 in this layer; other SegCaps layers use 1 or 3
return activations


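Before routing starts, we therefore have:
votes:
shape [N, input_num_capsule 1, height 256, width 256, num_capsule 2, num_atoms 16]
contents: the feature maps produced by the 5×5 convolution, split into one 16-atom vote per output capsule type

logit_shape:
shape (5,)
contents: [N, input_num_capsule 1, height 256, width 256, num_capsule 2]

biases:
shape [height 256, width 256, num_capsule 2, num_atoms 16]
contents: the trainable bias b tiled over every spatial position; in the repository b is initialized to the constant 0.1 rather than to random values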


update_routing

votes_trans = tf.transpose(votes, [5, 0, 1, 2, 3, 4])   # (num_atoms, N, input_num_capsule, 256, 256, num_capsule) = (16, N, 1, 256, 256, 2)

_, _, _, height, width, caps = votes_trans.get_shape()   # height = 256, width = 256, caps = 2
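Why move the atoms to the front? route has the logits' shape (N, input_num_capsule, H, W, num_capsule), which matches the trailing five axes of votes_trans, so route * votes_trans broadcasts each coupling coefficient across all 16 atoms of its vote. The NumPy sketch below (hypothetical small sizes, random data) checks this against the equivalent route[..., None] formulation.

```python
import numpy as np

# Hypothetical small sizes: 16 atoms, N=2, 1 input capsule type, 4x4 grid, 2 output capsules.
num_atoms, N, in_caps, h, w, num_capsule = 16, 2, 1, 4, 4, 2
votes = np.random.rand(N, in_caps, h, w, num_capsule, num_atoms)
route = np.random.rand(N, in_caps, h, w, num_capsule)   # same shape as logits

# With atoms on axis 0, the trailing five axes of votes_trans line up with route,
# so broadcasting multiplies every atom of a vote by that vote's coupling coefficient.
votes_trans = np.transpose(votes, [5, 0, 1, 2, 3, 4])   # (16, N, in_caps, h, w, num_capsule)
weighted = route * votes_trans                          # (16, N, in_caps, h, w, num_capsule)

# Same result without the transpose: give route a trailing axis instead.
weighted_alt = np.transpose(votes * route[..., None], [5, 0, 1, 2, 3, 4])

print(weighted.shape, np.allclose(weighted, weighted_alt))   # (16, 2, 1, 4, 4, 2) True
```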

About TensorArray and tf.while_loop

activations = tf.TensorArray(dtype=tf.float32, size=num_routing, clear_after_read=False)   # one slot per routing iteration; clear_after_read=False allows a slot to be read more than once

logits = tf.fill(logit_shape, 0.0)   # shape [N, input_num_capsule 1, height 256, width 256, num_capsule 2], all zeros

i = tf.constant(0, dtype=tf.int32)

_, logits, activations = tf.while_loop(
   lambda i, logits, activations: i < num_routing,  # condition for continuing the loop
   _body,   # operations executed in each iteration
   loop_vars=[i, logits, activations],  # initial loop state
   swap_memory=True)   # enable GPU-CPU memory swapping for the loop
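To make the control flow concrete, here is an eager Python/NumPy analogue of the tf.while_loop plus TensorArray pattern: a plain list stands in for the TensorArray, and eager_while_loop is a hypothetical helper, not a TensorFlow API. It also shows a consequence of the zero initialization: the first softmax gives every output capsule type the same coupling coefficient, 1/num_capsule (0.5 here). With routings=1, as in primarycaps, the only activation written therefore uses these uniform coefficients.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over `axis`, matching tf.nn.softmax semantics.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def eager_while_loop(cond, body, loop_vars):
    # Eager stand-in for tf.while_loop: keep calling body while cond holds.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

num_routing = 1
logits = np.zeros((2, 1, 4, 4, 2), dtype=np.float32)   # like tf.fill(logit_shape, 0.0)
routes = [None] * num_routing                           # plays the role of the TensorArray

def body(i, logits, routes):
    routes[i] = softmax(logits, axis=-1)                # like activations.write(i, ...)
    return i + 1, logits, routes

i, logits, routes = eager_while_loop(lambda i, l, r: i < num_routing, body, [0, logits, routes])
print(i, routes[0][0, 0, 0, 0])   # 1 [0.5 0.5]
```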

The loop body _body()

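def _body(i, logits, activations):
    """Routing while loop."""
    # route has the same shape as logits: [N, input_num_capsule, 256, 256, num_capsule]
    # logits plays the role of b in the original capsule network, route plays the role of c
    # softmax over the last axis normalizes each vote's weights across the output capsule types
    route = tf.nn.softmax(logits, axis=-1)   # axis=-1 replaces the deprecated dim=-1
    # c_r * u; preactivate_unrolled: [num_atoms, N, input_num_capsule, 256, 256, num_capsule]
    preactivate_unrolled = route * votes_trans
    # preact_trans: [N, input_num_capsule, 256, 256, num_capsule, num_atoms]
    preact_trans = tf.transpose(preactivate_unrolled, [1, 2, 3, 4, 5, 0])
    # sum over the input_num_capsule axis, like the weighted sum in the original capsule network, plus a bias
    # preactivate: [N, 256, 256, num_capsule, num_atoms]
    preactivate = tf.reduce_sum(preact_trans, axis=1) + biases
    # a_r = squash(sum c_r * u + bias)
    activation = _squash(preactivate)
    # store the result of this routing iteration
    activations = activations.write(i, activation)
    # replicate a_r along the input-capsule axis so it lines up with votes
    act_3d = K.expand_dims(activation, 1)          # [N, 1, 256, 256, num_capsule, num_atoms]
    tile_shape = np.ones(num_dims, dtype=np.int32).tolist()
    tile_shape[1] = input_dim
    act_replicated = tf.tile(act_3d, tile_shape)   # [N, input_num_capsule, 256, 256, num_capsule, num_atoms]
    # agreement a_r · u, a dot product over the atom axis: [N, input_num_capsule, 256, 256, num_capsule]
    distances = tf.reduce_sum(votes * act_replicated, axis=-1)
    # b_{r+1} = b_r + a_r · u
    logits += distances
    return (i + 1, logits, activations)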

return K.cast(activations.read(num_routing - 1), dtype='float32')
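Putting the pieces together, here is a NumPy sketch of the whole routing procedure described above. It is an illustrative re-implementation, not the repository code. _squash is not shown in this article, so the standard capsule squash v = (|s|^2 / (1 + |s|^2)) · s / |s| is assumed, with a small eps added to avoid division by zero. The explicit tf.tile is replaced by broadcasting, update_routing_np is a hypothetical name, and the sizes and data are made up. The last check confirms that with num_routing=1 the output reduces to squash(sum over input capsules of votes / num_capsule + biases).

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax, same semantics as tf.nn.softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def squash(s, axis=-1, eps=1e-9):
    # Standard capsule squash (assumed): (|s|^2 / (1 + |s|^2)) * s / |s|.
    # eps is an added guard against division by zero, not from the article.
    sq_norm = np.sum(s * s, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def update_routing_np(votes, biases, num_routing):
    """Hypothetical NumPy sketch of update_routing.

    votes:  (N, in_caps, H, W, num_capsule, num_atoms)
    biases: (H, W, num_capsule, num_atoms)
    returns the last iteration's activation: (N, H, W, num_capsule, num_atoms)
    """
    votes_trans = np.transpose(votes, [5, 0, 1, 2, 3, 4])
    logits = np.zeros(votes.shape[:-1], dtype=votes.dtype)   # logit_shape, all zeros
    activations = []                                         # plays the role of the TensorArray
    for _ in range(num_routing):
        route = softmax(logits, axis=-1)                     # c_r
        preact = np.transpose(route * votes_trans, [1, 2, 3, 4, 5, 0])
        activation = squash(preact.sum(axis=1) + biases)     # a_r
        activations.append(activation)
        # a_r broadcasts over the input-capsule axis, so no explicit tile is needed
        distances = np.sum(votes * activation[:, None], axis=-1)   # a_r · u
        logits = logits + distances                          # b_{r+1} = b_r + a_r · u
    return activations[-1]

rng = np.random.default_rng(0)
votes = rng.normal(size=(2, 3, 4, 4, 2, 16)).astype(np.float32)
biases = np.full((4, 4, 2, 16), 0.1, dtype=np.float32)
out = update_routing_np(votes, biases, num_routing=3)
print(out.shape)   # (2, 4, 4, 2, 16)
```

Because squash maps every vector to a length strictly below 1, each output capsule's length can be read as the strength of that capsule type at that pixel.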

Article link: https://www.haomeiwen.com/subject/ydsmwktx.html