[PyTorch] Store preprocessed data with torch.save

Author: VanJordan | Published 2019-06-21 11:14
    • Text files that need preprocessing can be processed once and then stored as a binary file with torch.save (or pickle), so the next run can load them directly instead of redoing the work.
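
A minimal sketch of the pattern (the cache path and the preprocessing step here are illustrative, not from the example below):

    import os
    import torch

    CACHE = 'dataset_cache.bin'  # illustrative cache path

    if os.path.isfile(CACHE):
        data = torch.load(CACHE)       # later runs: load the cached binary directly
    else:
        data = expensive_preprocess()  # hypothetical one-time processing step
        torch.save(data, CACHE)        # torch.save serializes arbitrary picklable objects

The full function below applies the same idea to a tokenized text dataset, with the imports it needs added up front: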
    import os
    import logging

    import torch
    from tqdm import tqdm
    # cached_path (download-and-cache helper) ships with pytorch_pretrained_bert
    from pytorch_pretrained_bert import cached_path

    logger = logging.getLogger(__name__)

    # DATASETS_URL, DATASETS_LABELS_URL and DATASETS_LABELS_CONVERSION are
    # module-level dicts (defined elsewhere in the original script) that map
    # dataset names to download URLs and label-to-integer conversion tables.

    def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cache=None, with_labels=False):
        """ Retrieve, tokenize, encode and cache a dataset with optional labels """
        if dataset_cache and os.path.isfile(dataset_cache):
            logger.info("Load encoded dataset from cache at %s", dataset_cache)
            encoded_dataset = torch.load(dataset_cache)
        else:
            # If the dataset is in our list of DATASETS_URL, use this url, otherwise, look for 'train.txt' and 'valid.txt' files
            if dataset_dir in DATASETS_URL:
                dataset_map = DATASETS_URL[dataset_dir]
            else:
                dataset_map = {'train': os.path.join(dataset_dir, 'train.txt'),
                               'valid': os.path.join(dataset_dir, 'valid.txt')}
    
            logger.info("Get dataset from %s", dataset_dir)
            # Download and read the dataset, replacing a few tokens for compatibility with the Bert tokenizer we are using
            dataset = {}
            for split_name in dataset_map.keys():
                dataset_file = cached_path(dataset_map[split_name])
                with open(dataset_file, "r", encoding="utf-8") as f:
                    all_lines = f.readlines()
                    dataset[split_name] = [
                            line.strip(' ').replace('<unk>', '[UNK]').replace('\n', '[SEP]' if not with_labels else '')
                            for line in tqdm(all_lines)]
    
            # If we have labels, download them and convert each label to an integer
            labels = {}
            if with_labels:
                label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
                for split_name in DATASETS_LABELS_URL[dataset_dir]:
                    dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
                    with open(dataset_file, "r", encoding="utf-8") as f:
                        all_lines = f.readlines()
                        labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]
    
            # Tokenize and encode the dataset
            logger.info("Tokenize and encode the dataset")
            logging.getLogger("pytorch_pretrained_bert.tokenization").setLevel(logging.ERROR)  # No warning on sample size
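            # encode() walks the nested dataset dict recursively: a string becomes a
            # list of token ids, a dict is encoded value-by-value, and any other
            # iterable is encoded item-by-item.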
            def encode(obj):
                if isinstance(obj, str):
                    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
                if isinstance(obj, dict):
                    return dict((n, encode(o)) for n, o in obj.items())
                return list(encode(o) for o in tqdm(obj))
            encoded_dataset = encode(dataset)
    
            # Add labels if needed. For language modeling, flatten each split into
            # one long list of token ids and record the word count (for word-level ppl).
            for split_name in ['train', 'valid']:
                if with_labels:
                    encoded_dataset[split_name + '_labels'] = labels[split_name]
                else:
                    encoded_dataset[split_name] = [ind for line in encoded_dataset[split_name] for ind in line]
                    encoded_dataset[split_name + '_num_words'] = sum(len(line.split(' ')) for line in dataset[split_name])
    
            # Save to cache
            if dataset_cache:
                logger.info("Save encoded dataset to cache at %s", dataset_cache)
                torch.save(encoded_dataset, dataset_cache)
    
        return encoded_dataset
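
A hedged usage sketch (the model name and cache file name are illustrative; BertTokenizer comes from the same pytorch_pretrained_bert package the function relies on):

    from pytorch_pretrained_bert import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # First run: downloads, tokenizes, encodes, then writes the cache with torch.save.
    # Later runs: the os.path.isfile check succeeds and torch.load returns immediately.
    dataset = get_and_tokenize_dataset(tokenizer,
                                       dataset_dir='wikitext-103',
                                       dataset_cache='wikitext-103_bert_cache.bin')
    print(dataset['train_num_words'])  # word count recorded for word-level perplexity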
    
