# 1、训练历史可视化
import matplotlib.pyplot as plt
def _history_series(hist, *keys):
    """Return the per-epoch series stored under the first of *keys* present in *hist*, else None."""
    for key in keys:
        if key in hist:
            return hist[key]
    return None


def showhist(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
        history: A Keras ``History`` object (as returned by ``model.fit``) or
            its ``history`` dict, mapping metric names such as ``'accuracy'``,
            ``'val_accuracy'``, ``'loss'`` and ``'val_loss'`` to per-epoch
            values. Series that are missing, e.g. the ``val_*`` entries when
            ``fit`` ran without validation data, are skipped instead of
            raising ``KeyError``. The older Keras names ``'acc'`` /
            ``'val_acc'`` are accepted as fallbacks for the accuracy panel.
    """
    # Accept either the History object or the plain dict it wraps.
    hist = getattr(history, 'history', history)

    # One entry per subplot: ([(candidate keys, legend label), ...],
    #                         title, y-label, legend location)
    panels = [
        ([(('accuracy', 'acc'), 'train_acc'),
          (('val_accuracy', 'val_acc'), 'val_acc')],
         'train and validation accuracy', 'Accuracy', 'lower right'),
        ([(('loss',), 'train_loss'),
          (('val_loss',), 'val_loss')],
         'train and validation loss', 'loss', 'upper right'),
    ]

    fig, ax = plt.subplots(1, 2, figsize=(8, 2))
    for axis, (series, title, ylabel, loc) in zip(ax, panels):
        for keys, label in series:
            values = _history_series(hist, *keys)
            if values is not None:
                axis.plot(values, label=label)
        axis.set_title(title)
        axis.set_ylabel(ylabel)
        axis.set_xlabel('Epoch')
        # Only add a legend when at least one curve was drawn; an empty
        # legend() call just emits a warning.
        if axis.get_lines():
            axis.legend(loc=loc)
    # The figure is only 2 inches tall; tighten spacing so titles and axis
    # labels of the two panels do not overlap.
    fig.tight_layout()
    plt.show()
# 2、模型结果评价
def model_assess(y_true, y_prob, roc=True, pr=False, name='RF', threshold=0.5):
    """Score binary-classifier probabilities and optionally plot ROC / PR curves.

    Args:
        y_true: Ground-truth labels, coercible to an integer array of 0/1.
        y_prob: Predicted positive-class probabilities, same length as
            ``y_true``. Lists and column vectors (e.g. an ``(n, 1)`` array)
            are accepted; both inputs are flattened to 1-D.
        roc: If true (default), plot the ROC curve on the current pyplot
            axes, labelled ``'<name>:<auc>'`` so several models can be
            overlaid by calling this function repeatedly.
        pr: If true, plot the precision-recall curve (recall on x, precision
            on y). When ``roc`` is also true the PR curve is drawn on a new
            figure so the two plots do not share axes.
        name: Model name used in the curve legends (default ``'RF'``).
        threshold: Probability at or above which a sample is predicted
            positive (default 0.5).

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``
        (computed from the thresholded predictions) and ``'auc'`` (ROC AUC
        computed from the raw probabilities), each rounded to 3 decimals.
    """
    y_true = np.asarray(y_true, dtype=int).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    # Hard predictions: positive when the probability reaches the threshold.
    y_pred = (y_prob >= threshold).astype(int)

    scores = {
        'accuracy': round(metrics.accuracy_score(y_true, y_pred), 3),
        'precision': round(metrics.precision_score(y_true, y_pred), 3),
        'recall': round(metrics.recall_score(y_true, y_pred), 3),
        'f1': round(metrics.f1_score(y_true, y_pred), 3),
        'auc': round(metrics.roc_auc_score(y_true, y_prob), 3),
    }

    if roc:
        fpr, tpr, _ = metrics.roc_curve(y_true, y_prob)
        plt.plot(fpr, tpr, label=name + ':' + str(scores['auc']))
        plt.title('ROC curve')
        plt.xlabel('fpr')
        plt.ylabel('tpr')
        plt.legend()

    if pr:
        if roc:
            # Keep the PR curve off the ROC axes.
            plt.figure()
        precision, recall, _ = metrics.precision_recall_curve(y_true, y_prob)
        plt.plot(recall, precision, label=name)
        plt.title('PR curve')
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.legend()

    return scores
网友评论