美文网首页
机器学习一些代码记录

机器学习一些代码记录

作者: IT_小马哥 | 来源:发表于2022-07-24 19:01 被阅读0次

    计算多分类时的每个类别的F1

    • 接口
    sklearn.metrics.classification_report(y_true, y_pred, labels=None, target_names=None, sample_weight=None, digits=2, output_dict=False)
    

    示例:

    from sklearn.metrics import classification_report
    y_true = [0,0, 1, 2, 2, 2, 0]
    y_pred = [0, 1, 0, 2, 2, 1, 0]
    target_names = ['dog', 'pig', 'cat']
    result = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
    print(result)
    
    image.png

    pytorch 使用K-折交叉验证

    pytorch 使用K-折交叉验证

    核心代码

      # Define the K-fold Cross Validator
      kfold = KFold(n_splits=k_folds, shuffle=True)
    
      # K-fold Cross Validation model evaluation
      for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset))
        
        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
        
        # Define data loaders for training and testing data in this fold
        trainloader = torch.utils.data.DataLoader(
                          dataset, 
                          batch_size=10, sampler=train_subsampler)
        testloader = torch.utils.data.DataLoader(
                          dataset,
                          batch_size=10, sampler=test_subsampler)
    

    相关文章

      网友评论

          本文标题:机器学习一些代码记录

          本文链接:https://www.haomeiwen.com/subject/vylwirtx.html