import numpy as np
# 根据计算原理:t_list[i] * clip_norm / max(global_norm, clip_norm)
# 生成0-9之间的数组成的列表
init_t_list = np.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# 方式1:使用np自带的函数
l2 = np.linalg.norm(init_t_list)
print(l2)
# 方式2:手写实现方式
l2_ = np.sqrt(np.sum(np.square(init_t_list)))
print(l2_)
# 假设裁剪规约数等于5.0
clip_norm = 5.0
# 求裁剪后的值
t_list = init_t_list * clip_norm / max(l2, clip_norm)
print(t_list)
# 裁剪后L2值
t_list_l2 = np.linalg.norm(t_list)
print(t_list_l2)
# 输出结果
16.8819430161
16.8819430161
[ 0. 0.29617444 0.59234888 0.88852332 1.18469776 1.48087219
1.77704663 2.07322107 2.36939551 2.66556995]
5.0
网友评论