美文网首页
class_weight,字典,下采样,list, idx, 训练数据格式

作者: 菌子甚毒 | 来源:发表于2022-05-16 15:59 被阅读0次
  1. 对于:
    torch.arange(0,10).reshape(2,-1) 与 torch.arange(0,10).reshape(-1,5)
    结果相同:都是 2×5 的张量(-1 表示该维度由元素总数自动推断),逐元素比较 == 全为 True,没有区别。
  2. 读字典:
# `d` stands for any dict instance. Do not name a variable `dict`: that
# shadows the built-in type, and e.g. `dict.values()` on the type itself
# raises TypeError.
d = {'key1': 1, 'key2': 2}
d.values()   # dict_values([1, 2])
d.keys()     # dict_keys(['key1', 'key2'])
d['key1']    # 1 -- raises KeyError if the key is missing (d.get('key1') returns None instead)
list(d)      # ['key1', 'key2'] -- iterating a dict yields its keys
  3. list 的拆分(解包):
# Unpack a two-element list straight into two names (list-style target)
[a, b] = [['a', 'b', 'c'], [1, 2, 3]]
# Same text as print(a, '\n', b): the default ' ' separator surrounds the newline
print(a, b, sep=' \n ')
"""
output:
['a', 'b', 'c'] 
 [1, 2, 3]
"""
  4. random.sample():从序列中无放回地随机抽取 k 个元素
# eg1: draw 4 distinct elements from a list (the order of the result is random)
a = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
random.sample(a, 4)
"""
output:
['a', 'd', 'f', 'c']
"""
# eg2: a dict is not accepted as the population of random.sample()
b = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}
try:
    random.sample(b, 2)
except TypeError as err:
    # Report instead of aborting the snippet. The wording is version-dependent:
    # Python 3.11+ says "Population must be a sequence. For dicts or sets, use sorted(d)."
    print(f"{type(err).__name__}: {err}")
"""
output:
TypeError: Population must be a sequence or set. For dicts, use list(d).
故改为:
"""
# Sample from the dict's keys instead, then rebuild a dict holding only the chosen entries
sampled_b = {each_key: b[each_key] for each_key in random.sample(list(b), 2)}
sampled_b
"""
{'b': 2, 'e': 5}
"""

下采样代码:从字典的所有值中一共随机抽取 10 个,再按原来的键还原成字典形式:

# Build an example dict: each key maps to a list of values
c = {'a': list(range(10)), 'b': list(range(100, 110))}

# Total number of values to keep, summed over all keys
n_samples = 10

# Flatten the dict into (key, value) pairs so every sampled value remembers
# its source key. Re-discovering the key from the value afterwards is slow
# (a scan of every key's list per sample). It is also wrong when the same
# value occurs under several keys: the value would be added to each of them.
pairs = [(key, value) for key, values in c.items() for value in values]
value_flatten = [value for _, value in pairs]
# Alternative for the flat value list: sum(c.values(), [])

# Sample without replacement
sampled_pairs = random.sample(pairs, n_samples)
sampled_value = [value for _, value in sampled_pairs]

# Regroup the sampled values by their source key
# (keys appear in order of their first sampled value)
zeros = {}
for key, value in sampled_pairs:
    zeros.setdefault(key, []).append(value)
"""
各个对象采样加起来一共10个。
采样前:
{'a': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 'b': [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]}
采样后:
{'a': [6, 5, 3, 0, 4], 'b': [108, 100, 104, 105, 109]}
"""
  5. class_weight(类别权重)
# Build toy data for two classes. Each subject ID (as a string) maps to
# 5-8 distinct indices drawn from range(10).
zeros = dict(
    (str(subject_id), random.sample(range(10), random.randint(5, 8)))
    for subject_id in range(300, 310)
)
ones = dict(
    (str(subject_id), random.sample(range(10), random.randint(5, 8)))
    for subject_id in range(400, 415)
)
"""
zeros:
{'300': [4, 1, 8, 5, 0],
  '301': [9, 6, 2, 1, 3, 5],
  ...
  '309': [6, 3, 7, 4, 8, 0, 5]},
ones:
 {'400': [6, 2, 1, 5, 9, 7],
  '401': [0, 3, 6, 7, 8, 9, 4],
  '402': [1, 4, 5, 8, 3],
  ...
  '414': [9, 0, 6, 2, 3, 4, 5, 7]}
"""

(1) class weight

# Total number of list entries per class, summed over its subjects
instance_of_zero = sum(map(len, zeros.values()))  # 64
instance_of_one = sum(map(len, ones.values()))    # 88
# Smaller and larger subject count of the two classes
min_class, max_class = sorted((len(zeros), len(ones)))  # 10, 15
class_weight = min_class / max_class    # 0.667
# Position of the class with more subjects (0 on a tie, like list.index would give)
index_to_alter = 0 if len(zeros) == max_class else 1   # 1
weights = [1, 1]
weights[index_to_alter] = class_weight  # [1, 0.667]

# Per-subject weights: every subject inherits the weight of its class
weights_dict = {**dict.fromkeys(zeros, weights[0]), **dict.fromkeys(ones, weights[1])}
"""
weights_dict:
{'300': 1,
 '301': 1,
 '302': 1,
 ....
 '400': 0.6666666666666666,
 '401': 0.6666666666666666,
 '402': 0.6666666666666666,
 ...}
"""

(2) micro weight

WEIGHT_TYPE = 'micro'
# Shortest list length within each class
min_zeros = min(map(len, zeros.values())) # 5
min_ones = min(map(len, ones.values()))   # 5
# Each subject is weighted by (shortest list in its class) / (its own list length)
micro_weights = {subject: min_zeros / len(segments) for subject, segments in zeros.items()}
micro_weights.update({subject: min_ones / len(segments) for subject, segments in ones.items()})
weights_dict = micro_weights

(3) combine (1) & (2):micro weight 再乘以 class weight

# Combine the two schemes above: micro weight multiplied by the class weight
WEIGHT_TYPE = 'both'
min_zeros = min(map(len, zeros.values())) # 5
min_ones = min(map(len, ones.values()))   # 5
micro_weights = {}
for group, shortest, class_factor in ((zeros, min_zeros, weights[0]), (ones, min_ones, weights[1])):
    for subject, segments in group.items():
        w = shortest / len(segments)
        if WEIGHT_TYPE == 'both':
            w = w * class_factor
        micro_weights[subject] = w
weights_dict = micro_weights

(4) class weight (segment-based):按各类所有 subject 的列表长度总和(而不是 subject 数)计算类别权重

WEIGHT_TYPE = 'instance'
if WEIGHT_TYPE == 'instance':
    # Class weight from the per-class entry totals instead of subject counts
    min_tmp, max_tmp = sorted((instance_of_zero, instance_of_one))

    tmp_weight = min_tmp / max_tmp
    # Position of the larger class (0 on a tie, matching list.index)
    index_to_alter = 0 if instance_of_zero == max_tmp else 1

    weights = [1, 1]
    weights[index_to_alter] = tmp_weight

    # Every subject inherits its class weight; `ones` entries are written last
    instance_weights = dict.fromkeys(zeros, weights[0])
    instance_weights.update(dict.fromkeys(ones, weights[1]))
    weights_dict = instance_weights
  6. idx:用整数列表索引 numpy 数组
import numpy as np
a = np.arange(10)  # array([0, 1, ..., 9])
idx = [5, 6, 7]
# Integer-list ("fancy") indexing picks the elements at the listed positions
a[idx]
# output: array([5, 6, 7])
  7. 训练数据的格式(numpy 数组 -> torch 张量):
    • float -> torch.Tensor(x)(默认 float32)
    • int -> torch.LongTensor(x)(int64)
def create_tensor_data(x, cuda):
    """
    Convert array data (e.g. numpy arrays) to torch tensors, choosing the
    tensor type from the name of ``x.dtype``.

    Inputs
        x: The input data; must expose a ``dtype`` attribute
        cuda: Bool - Set to true if using the GPU (moves the result with ``.cuda()``)

    Output
        x: Data converted to a tensor:
           - dtype name containing 'float' -> torch.Tensor(x)
           - dtype name containing 'int' (signed or unsigned) -> torch.LongTensor(x)

    Raises
        TypeError: if the dtype name contains neither 'float' nor 'int'
        (e.g. bool, complex, object). TypeError is a subclass of Exception,
        so callers that catch Exception keep working.
    """
    # Match on the dtype's string name; this also accepts non-numpy dtypes
    # whose name contains 'float'/'int'.
    dtype_name = str(x.dtype)
    if 'float' in dtype_name:
        x = torch.Tensor(x)
    elif 'int' in dtype_name:
        x = torch.LongTensor(x)
    else:
        raise TypeError(
            f"create_tensor_data: unsupported dtype {dtype_name!r}; "
            "expected a float or int array"
        )

    if cuda:
        x = x.cuda()

    return x

相关文章

网友评论

      本文标题:class_weight,字典,下采样,list, idx, 训练数据格式

      本文链接:https://www.haomeiwen.com/subject/zjgvurtx.html