Code Implementation (5): Relation Network for Few-Shot Learning

Author: 续袁 | Published: 2019-08-06 18:51

    1. Environment requirements

    (1) PyTorch 1.0
    (2) Python 3.6
    (3) numpy
    (4) scipy
    (5) matplotlib
    (6) torchvision
    (7) PIL (on newer Python versions, install Pillow instead)

    1.2 Installing torchvision

    conda install torchvision -c pytorch
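    After installing, a short stdlib-only script can confirm that every package listed in section 1 is importable. This is a minimal sketch: the helper name missing_packages is hypothetical, and it only checks that each package can be found on the current interpreter, not that its version matches (e.g. PyTorch 1.0). Pillow is imported under the name PIL.

```python
import importlib.util

# Import names of the required packages; Pillow is imported as "PIL".
REQUIRED = ["torch", "torchvision", "numpy", "scipy", "matplotlib", "PIL"]


def missing_packages(names=REQUIRED):
    """Return the names from `names` that cannot be found on this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all required packages are installed")
```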
    
    

    2. Running the code

    2.1 Problem: the original code runs on the GPU; switching it to the CPU

    # Step 1: comment out the following two lines
    # feature_encoder.cuda(GPU)
    # relation_network.cuda(GPU)
    # Step 2: pass map_location='cpu' to torch.load; otherwise loading the saved models fails with:
    RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.
    # Fix: add the argument map_location='cpu' to both torch.load calls
        encoder_path = "./models/omniglot_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
        relation_path = "./models/omniglot_relation_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
        # map_location='cpu' remaps tensors that were saved on a GPU onto the CPU
        if os.path.exists(encoder_path):
            feature_encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
            print("load feature encoder success")
        if os.path.exists(relation_path):
            relation_network.load_state_dict(torch.load(relation_path, map_location='cpu'))
            print("load relation network success")
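    Instead of commenting the .cuda(GPU) calls in and out by hand, a device-agnostic variant can choose the device at runtime. The sketch below relies only on standard PyTorch calls (torch.device, torch.cuda.is_available, torch.load with map_location, Module.load_state_dict, Module.to); the helper load_if_exists is hypothetical and not part of the original repository. Any input tensors fed to the models would need the same .to(device) treatment.

```python
import os

import torch
import torch.nn as nn

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_if_exists(model: nn.Module, path: str) -> bool:
    """Load a saved state_dict into `model` and move it to `device`.

    Returns True if the file existed and was loaded, False otherwise.
    """
    if not os.path.exists(path):
        return False
    # map_location remaps storages saved on a GPU to the device chosen above
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    return True
```

    Usage would mirror the if-blocks above, e.g. load_if_exists(feature_encoder, encoder_path), with each model's .cuda(GPU) call replaced by .to(device).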
    
    

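    2.2 Problem: KeyError: '..\datas\omniglot_resized'

    Cause: Linux and Windows use different path separators ('/' versus '\').
    Fix: in get_class, split on '\\' (a single backslash in Python source) instead of '/'. Note that this makes the function Windows-only:
    def get_class(self, sample):
        # Drop the file name and rejoin the remaining parts into the class directory
        return os.path.join(*sample.split('\\')[:-1])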
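    A portable alternative, assuming get_class is only meant to return the parent directory of a sample file, is to let os.path handle the separator for the current OS. This sketch is written as a plain function rather than the original method, and it assumes the class-to-label dictionary is keyed with paths built through os.path as well; otherwise the KeyError can reappear.

```python
import os


def get_class(sample: str) -> str:
    """Return the class directory of a sample path (the path minus its file name).

    os.path follows the current OS: '/' on Linux, and both '\\' and '/'
    on Windows, so no hard-coded separator is needed.
    """
    return os.path.dirname(os.path.normpath(sample))
```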
    

    2.3 Problem: the following error is raised:

    File "/LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 193, in main
        torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1))
    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 'index'
    
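    Fix: add batch_labels = batch_labels.long() before this line. scatter_ expects its index argument (batch_labels here) to be a Long (int64) tensor, but the labels arrive as Int (int32).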
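    The failing line builds a one-hot target matrix with scatter_, whose index must be int64. A self-contained sketch of the same construction, with the cast folded into a hypothetical helper one_hot so that int32 labels also work:

```python
import torch


def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Build an (N, num_classes) one-hot float matrix from integer class labels."""
    # scatter_ requires an int64 (Long) index, so cast labels that may be int32
    index = labels.long().view(-1, 1)
    # Write 1 into column labels[i] of row i; every other entry stays 0
    return torch.zeros(index.size(0), num_classes).scatter_(1, index, 1)


labels = torch.tensor([2, 0, 1], dtype=torch.int32)
print(one_hot(labels, 3))  # rows: [0, 0, 1], [1, 0, 0], [0, 1, 0]
```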
    

    2.4 Problem: IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

    Error message:
    File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
        print("episode:",episode+1,"loss",loss.data[0])
    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    
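    As the message suggests, change the line to
    print("episode:", episode + 1, "loss", loss.item())
    and the error goes away: in PyTorch 1.0 the loss is a 0-dim (scalar) tensor, so it cannot be indexed with [0], while .item() returns its value as a Python number.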
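    A small demonstration, assuming PyTorch 1.0 or later: a reduced loss is a 0-dim tensor, .item() extracts its value, and the loss.data[0] idiom from the original script raises the IndexError above. The MSELoss inputs are arbitrary illustration values.

```python
import torch

# A reduced loss such as MSELoss is a 0-dim (scalar) tensor
loss = torch.nn.MSELoss()(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))
print(loss.dim())   # 0
print(loss.item())  # 2.0, a plain Python float

try:
    loss.data[0]    # indexing a 0-dim tensor, as in the original print statement
except IndexError as err:
    print("IndexError:", err)
```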
    

    2.5 Problem

      File "C:/Users/xpb/PycharmProjects/LearningToCompare_FSL-master/omniglot/omniglot_train_one_shot.py", line 268, in <listcomp>
        rewards = [1 if predict_labels[j]==test_labels[j] else 0 for j in range(CLASS_NUM)]
    RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'other'
    
    Fix: add these lines before it, so both tensors are Long (int64):
    predict_labels = predict_labels.long()
    test_labels = test_labels.long()
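    The list comprehension in the traceback compares the two label tensors element by element. A vectorized sketch of the same reward computation, assuming both tensors have the same length and every element should be compared; compute_rewards is a hypothetical helper, not part of the repository:

```python
import torch


def compute_rewards(predict_labels: torch.Tensor, test_labels: torch.Tensor) -> list:
    """Return 1 for each correct prediction and 0 otherwise, as a Python list."""
    # Cast both to int64 so the comparison never mixes Int and Long tensors
    predict_labels = predict_labels.long()
    test_labels = test_labels.long()
    # Element-wise equality gives a bool tensor; convert it to 0/1 integers
    return (predict_labels == test_labels).long().tolist()


pred = torch.tensor([0, 1, 2, 3, 4])                      # int64 (default dtype)
truth = torch.tensor([0, 1, 3, 3, 4], dtype=torch.int32)  # int32, as in the error
rewards = compute_rewards(pred, truth)
print(rewards, sum(rewards))  # [1, 1, 0, 1, 1] 4
```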
    

    3. Code walkthrough

    Code

    [1] floodsung/LearningToCompare_FSL
    [2] prolearner/LearningToCompareTF

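    References

    [1] Introduction to the torchvision library (translation)
    [2] PyTorch: torchvision, the computer-vision toolkit
    [3] Python 3.7.0: how to install PIL
    [4] A detailed introduction to the modules of the Python imaging library PIL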

    Troubleshooting

    [0] Learning to Compare: Relation Network, debugging the source code
    [1] On slashes in file paths when reading files in Python
    [2] Converting backslashes in a path to '/' in Python
    [3] Using os.path.join() to join paths in Python
    [4] RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1
    [5] GPU computation (CUDA) in PyTorch

    Paper

    [1] Learning to Compare: Relation Network for Few-Shot Learning
