美文网首页
训练与测试时的One-hot编码一致问题

训练与测试时的One-hot编码一致问题

作者: 无量儿 | 来源:发表于2023-03-20 13:59 被阅读0次

神经网络训练模型已落地,所有变量都已释放。现在加载模型对测试集进行测试,如何保证测试集One-hot编码与之前训练的模型One-hot编码一致

在测试阶段,我们需要保证测试集的One-hot编码与之前训练模型使用的编码方式一致,以确保正确性和可靠性。以下是一些常见的方法:

1.使用与训练集相同的编码方式:在训练神经网络时,我们通常会对训练集进行One-hot编码处理,这个编码方式通常会在训练的前几步定义。在测试集上进行测试时,我们可以使用与训练集相同的编码方式进行处理,以确保一致性。

2.保存编码方式并在测试时重新加载:如果您的模型使用的是独立于训练集的编码方式,您可以将其保存为文件并在测试时重新加载。例如,如果您使用pickle或json格式保存了编码方式,可以使用相应的函数重新加载并在测试集上使用。

3.将编码方式嵌入到模型中:在某些情况下,您可以将编码方式嵌入到模型中。这可以通过在模型中添加额外的层来实现,这些层可以处理数据编码和解码,以确保编码一致性。在测试时,您可以使用这些层来编码测试数据,并使用与训练时相同的方法进行解码。

总的来说,保持一致性是关键,不同的方法可能适用于不同的情况。在测试阶段,一定要确保您使用的编码方式与之前训练时使用的编码方式一致。

以下是对上述第2点的代码实现,并且为两个特征的情况:

import pickle
import numpy as np
from sklearn.preprocessing import OneHotEncoder

onehot_encoder1 = OneHotEncoder()
onehot_encoder2 = OneHotEncoder()
np1 = np.array(['CTP733', 'CTP728', 'CTF382']).reshape(-1, 1)  # 训练时的输入
np2 = np.array(['WIFI', '5G', 'SUB6', 'Cell']).reshape(-1, 1)  # 训练时的输入
X_train_encoded1 = onehot_encoder1.fit_transform(np1)
X_train_encoded2 = onehot_encoder2.fit_transform(np2)
pickle.dump(onehot_encoder1, open('X_train_encoded1.pkl', 'wb'))  # 将编码方式保存在本地
pickle.dump(onehot_encoder2, open('X_train_encoded2.pkl', 'wb'))  # 将编码方式保存在本地
print(X_train_encoded1.toarray())
print(X_train_encoded2.toarray())
print('-'*40)

X_train_encoded_pre1 = pickle.load(open('X_train_encoded1.pkl', 'rb'))  # 预测时加载训练的编码方式
X_train_encoded_pre2 = pickle.load(open('X_train_encoded2.pkl', 'rb'))  # 预测时加载训练的编码方式
np_pre1 = np.array(['CTF382', 'CTF382', 'CTP728', 'CTP733']).reshape(-1, 1)  # 预测时的输入
np_pre2 = np.array(['SUB6']).reshape(-1, 1)  # 预测时的输入
result1 = X_train_encoded_pre1.transform(np_pre1).toarray()
result2 = X_train_encoded_pre2.transform(np_pre2).toarray()
print(result1)
print(result2)

相关文章

网友评论

      本文标题:训练与测试时的One-hot编码一致问题

      本文链接:https://www.haomeiwen.com/subject/tkhyrdtx.html