logits是网络的输出,logits.shape=(batch_size, w, h, 21),21类语义标签。
pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)#shape=(?, ?, ?, 1)
我们用numpy解释argmax和expand_dims这两个函数:
import numpy as np
#(2, 2, 4, 3)
x = np.array([[[[31,20,10],
[20,43,30],
[40,10,62],
[40,60,76]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]],
[[[10,22,10],
[20,30,81],
[40,10,62],
[40,65,30]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]]])
#(2, 2, 4)
y1 = np.argmax(x, axis=3)
'''[[[0 1 2 2]
[1 0 0 2]]
[[1 2 2 1]
[1 0 0 2]]]'''
#(2, 2, 4, 1)
y2 = np.expand_dims(y1, axis=3)
'''[[[[0]
[1]
[2]
[2]]
[[1]
[0]
[0]
[2]]]
[[[1]
[2]
[2]
[1]]
[[1]
[0]
[0]
[2]]]]'''
y2的值是每个一维列表的最大值的下标,如第一个值为0,是因为[31,20,10]中最大元素31的下标是0。
batch_size为1时的网络输出:
[[[31,20,10],
[20,43,30],
[40,10,62],
[40,60,76]],
[[10,72,20],
[81,30,40],
[97,50,70],
[40,50,68]]]
注意图像大小是2×4,而不是4×3或3×4。所以[31,20,10]是第一个像素被归类为第0类、第1类、第2类的概率。因为31最大,所以该像素的语义标签被归类为0。这样就可以解释y2:
[[[0]
[1]
[2]
[2]]
[[1]
[0]
[0]
[2]]]
batch_size=1时,它是指一个2×4的图像的第(0,0)个像素标签为0、第(0,1)个像素标签为1、... 、第(1,3)个像素标签为2。
得到预测分割的0,1,2结果后,分别对pred,ground truth 进行one-hot后评估。
def one_hot(label):
'''Convert label (d,h,w) to one-hot label (d,h,w,num_class).
'''
num_class = np.max(label) + 1
return np.eye(num_class)[label]
gt_one_hot = one_hot(gt) # 将 GT one-hot
pred_one_hot = one_hot(pred) # 将预测的segmention one-hot
# 单独的每个类。0对应于background类(忽略)。
csf_pred = pred_one_hot[:,:,:,1]
csf_label = label_one_hot[:,:,:,1]
gm_pred = pred_one_hot[:,:,:,2]
gm_label = label_one_hot[:,:,:,2]
csf_dr = dice_ratio(csf_pred, csf_label)
gm_dr = dice_ratio(gm_pred, gm_label)
print('--->avg:', np.mean([csf_dr, gm_dr))
代码来自 Non-Local U-Net
网友评论