美文网首页
python -- argparse

python -- argparse

作者: DeepWeaver | 来源:发表于2017-12-05 21:51 被阅读26次
    def DataLoader(data_place):
        """Load the prepared DR-GAN arrays stored under ``data_place``.

        Three NumPy files are read from that directory: ``images.npy``,
        ``ids.npy`` and ``yaws.npy``. The identity and pose counts are derived
        from the label arrays; the noise size and channel count are fixed.

        Args:
            data_place: directory holding the ``.npy`` files (the command-line
                default is ``./data``).

        Returns:
            list: ``[images, id_labels, pose_labels, Nd, Np, Nz, channel_num]``
                images      -- array from ``images.npy``; presumably shaped
                               (num_images, channel, height, width) -- TODO confirm.
                id_labels   -- identity labels from ``ids.npy``, treated as
                               integer class indices (NOT one-hot vectors).
                pose_labels -- pose labels from ``yaws.npy``, likewise integer
                               indices rather than one-hot vectors.
                Nd          -- number of identities: ``max(id_labels) + 1``.
                Np          -- number of discrete poses: ``max(pose_labels) + 1``.
                Nz          -- noise vector size, fixed at 50 (the paper's default).
                channel_num -- number of image channels, fixed at 3.
        """
        noise_dim = 50
        channels = 3

        # Each template is filled with data_place, e.g. './data/images.npy'.
        file_templates = ('{}/images.npy', '{}/ids.npy', '{}/yaws.npy')
        images, id_labels, pose_labels = [
            np.load(template.format(data_place)) for template in file_templates
        ]

        # The dataset is assumed to hold Nd identities, each shown in the same
        # Np poses; an (identity, pose) label pair is the extra conditioning
        # information given to the generator.
        # The labels are integer indices, not one-hot vectors (they are not
        # converted here), so each class count is the largest label plus one.
        pose_count = int(pose_labels.max() + 1)
        id_count = int(id_labels.max() + 1)

        return [images, id_labels, pose_labels, id_count, pose_count, noise_dim, channels]
    
    
    if __name__ == "__main__":
        # Entry point: parse CLI options, prepare the run directory, load data and
        # model, then either train or generate pose-modified images.
        import sys  # sys.exit for error exits; the bare exit() builtin relies on the site module

        parser = argparse.ArgumentParser(description='DR_GAN')
        # learning & saving parameters
        parser.add_argument('-lr', type=float, default=0.0002, help='initial learning rate [default: 0.0002]')
        parser.add_argument('-beta1', type=float, default=0.5, help='adam optimizer parameter [default: 0.5]')
        parser.add_argument('-beta2', type=float, default=0.999, help='adam optimizer parameter [default: 0.999]')
        parser.add_argument('-epochs', type=int, default=1000, help='number of epochs for train [default: 1000]')
        parser.add_argument('-batch-size', type=int, default=8, help='batch size for training [default: 8]')
        parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot')
        parser.add_argument('-save-freq', type=int, default=1, help='save learned model for every "-save-freq" epoch')
        parser.add_argument('-cuda', action='store_true', default=False, help='enable the gpu')
        # data source
        parser.add_argument('-random', action='store_true', default=False, help='use randomely created data to run program, instead of from data-place')
        parser.add_argument('-data-place', type=str, default='./data', help='prepared data path to run program')
        # model
        parser.add_argument('-multi-DRGAN', action='store_true', default=False, help='use multi image DR_GAN model')
        parser.add_argument('-images-perID', type=int, default=0, help='number of images per person to input to multi image DR_GAN')
        # option
        parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot(snapshot/{Single or Multiple}/{date}/{epoch}) [default: None]')
        # store_true flags should default to False (was None), matching the other flags.
        parser.add_argument('-generate', action='store_true', default=False, help='Generate pose modified image from given image')

        args = parser.parse_args()

        # Each run gets its own directory: <save-dir>/{Multi|Single}/<timestamp>.
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        model_kind = 'Multi' if args.multi_DRGAN else 'Single'
        args.save_dir = os.path.join(args.save_dir, model_kind, timestamp)
        os.makedirs(args.save_dir)

        # Echo every parameter and record it in <save_dir>/Parameters.txt
        # (the file is opened once rather than once per parameter).
        print("Parameters:")
        with open('{}/Parameters.txt'.format(args.save_dir), 'a') as f:
            for attr, value in sorted(vars(args).items()):
                text = "\t{}={}\n".format(attr.upper(), value)
                print(text, end='')  # text already ends with a newline
                f.write(text)

        # input data
        if args.random:
            images, id_labels, pose_labels, Nd, Np, Nz, channel_num = create_randomdata()
        else:
            print('\nLoading data from [%s]...' % args.data_place)
            try:
                images, id_labels, pose_labels, Nd, Np, Nz, channel_num = DataLoader(args.data_place)
            except (OSError, ValueError) as err:
                # Nothing below can run without data; previously execution fell
                # through and crashed later with a NameError.
                print("Sorry, failed to load data")
                print(err)
                sys.exit(1)

        # model
        if args.snapshot is None:
            if not args.multi_DRGAN:
                D = single_model.Discriminator(Nd, Np, channel_num)
                G = single_model.Generator(Np, Nz, channel_num)
            else:
                # Multi-image input must be told how many images each identity has.
                # (The original compared `args == 0`, which is never true.)
                if args.images_perID <= 0:
                    print("Please specify -images-perID of your data to input to multi_DRGAN")
                    sys.exit(1)
                D = multi_model.Discriminator(Nd, Np, channel_num)
                # The per-identity image count is passed on to the generator.
                G = multi_model.Generator(Np, Nz, channel_num, args.images_perID)
        else:
            print('\nLoading model from [%s]...' % args.snapshot)
            # NOTE(review): torch.load unpickles the file; only load trusted snapshots.
            try:
                D = torch.load('{}_D.pt'.format(args.snapshot))
                G = torch.load('{}_G.pt'.format(args.snapshot))
            except OSError:
                print("Sorry, This snapshot doesn't exist.")
                sys.exit(1)

        if not args.generate:
            if not args.multi_DRGAN:
                train_single_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D, G, args)
            # Guard against images_perID == 0 (possible when a snapshot is loaded),
            # which would otherwise raise ZeroDivisionError in the modulo.
            elif args.images_perID > 0 and args.batch_size % args.images_perID == 0:
                train_multiple_DRGAN(images, id_labels, pose_labels, Nd, Np, Nz, D, G, args)
            else:
                print("Please give valid combination of batch_size, images_perID")
                sys.exit(1)
        else:
            # One random pose code, uniform in [-1, 1), per input image.
            pose_code = np.random.uniform(-1, 1, (images.shape[0], Np))
            features = Generate_Image(images, pose_code, Nz, G, args)
    

    Something like this.

    相关文章

      网友评论

          本文标题:python -- argparse

          本文链接:https://www.haomeiwen.com/subject/xxjhixtx.html