PyTorch Lightning Notes

Author: IT_小马哥 | Published 2022-05-23 21:59

The LightningModule

    import warnings
    from argparse import ArgumentParser

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer, loggers as pl_loggers
    from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
    from torch.optim import AdamW
    from torch.utils.data import DataLoader
    from transformers import (AutoConfig, AutoModel, AutoTokenizer,
                              get_linear_schedule_with_warmup)
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
    # JSON_Dataset is the author's custom Dataset class (definition not shown in this post)

    class myNet(pl.LightningModule):
        def __init__(self, args, class_num, max_length):
            super().__init__()
            self.save_hyperparameters()
            self.args = args
            self.class_num = class_num
            self.max_length = max_length
    
            config = AutoConfig.from_pretrained(args.model_path)
            config.output_hidden_states = True
            self.bert = AutoModel.from_pretrained(args.model_path, config=config)
            self.tokenizer = AutoTokenizer.from_pretrained(args.model_path)
            self.fc = nn.Linear(self.bert.config.hidden_size, class_num)
            self.loss = nn.CrossEntropyLoss()
    
        def collate(self, batchdata):
            batchtext = [one[0] for one in batchdata]
            batchlabel = [one[1] for one in batchdata]
            tokens = self.tokenizer(batchtext, return_tensors="pt", padding=True,
                                    max_length=self.max_length, truncation=True)
            batchlabel = torch.tensor(batchlabel)
            return tokens, batchlabel
    
        def configure_optimizers(self):
    
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in self.bert.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay, 'lr': self.args.lr},
                {'params': [p for n, p in self.bert.named_parameters() if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0, 'lr': self.args.lr},
            {'params': self.fc.parameters(), 'lr': self.args.lr},
            ]
            optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.lr)
    
            total_steps = len(self.train_dataloader()) * self.args.epochs
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=self.args.warmup_steps,
                                                        num_training_steps=total_steps)
            scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
            return [optimizer], [scheduler]
    
        def forward(self, input_ids, mask):
            # minimal sketch of the head: classify the [CLS]-token representation
            out = self.bert(input_ids=input_ids, attention_mask=mask)
            return self.fc(out.last_hidden_state[:, 0])
    
        def training_step(self, batch, batch_idx):
            input_ids = batch[0]['input_ids']
            mask = batch[0]['attention_mask']
            labels = batch[1]
            logits = self(input_ids, mask)
            loss = self.loss(logits, labels)
    
            self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return {'loss': loss}
    
        def training_step_end(self, training_step_outputs):  # called at the end of each batch
            # training_step_outputs is whatever training_step returned
            pass
    
        def training_epoch_end(self, training_step_outputs):  # called at the end of each epoch
            # training_step_outputs is a list holding the return value of every step
            pass
    
        def share_val_step(self, batch, batch_idx):
            input_ids = batch[0]['input_ids']
            mask = batch[0]['attention_mask']
            labels = batch[1]
    
            out = self(input_ids, mask)
            loss_ce = self.loss(out, labels)
            y_pred_label = out.argmax(dim=1)
            testf1, testp, testr, testacc = self.evaluation(labels, y_pred_label)
            return loss_ce, testf1, testp, testr, testacc
    
        def validation_step(self, batch, batch_idx):
            loss_ce, testf1, testp, testr, testacc = self.share_val_step(batch, batch_idx)
            return {'val_loss_step': loss_ce, 'val_acc_step': testacc}
    
        def validation_step_end(self, val_step_outputs):
            pass
            # # predictions from each GPU
            # predictions = val_step_outputs["pred"]
            # # losses from each GPU
            # losses = val_step_outputs["loss"]
            #
            # gpu_0_prediction = predictions[0]
            # gpu_1_prediction = predictions[1]
            #
            # # do something with both outputs
            # return (losses[0] + losses[1]) / 2
    
        def validation_epoch_end(self, validation_step_outputs):
            val_loss = torch.stack([x["val_loss_step"] for x in validation_step_outputs]).mean()
            val_acc = torch.stack([x["val_acc_step"] for x in validation_step_outputs]).mean()
            self.log('val_acc', val_acc)
            self.print(val_acc)
            return {"val_loss": val_loss, "val_acc": val_acc}
    
        def test_step(self, batch, batch_idx):
            loss_ce, testf1, testp, testr, testacc = self.share_val_step(batch, batch_idx)
            metrics = {'test_loss_step': loss_ce, 'test_acc_step': testacc}
            self.log_dict(metrics)
            return metrics
    
        def test_step_end(self, output_results):
            pass
    
        def test_epoch_end(self, test_step_outputs):
            # do something with the outputs of all test batches
            test_loss = torch.stack([x["test_loss_step"] for x in test_step_outputs]).mean()
            test_acc = torch.stack([x["test_acc_step"] for x in test_step_outputs]).mean()
            metrics = {'test_loss': test_loss, 'test_acc': test_acc}  # printed to the console
            self.log_dict(metrics)
    
        def train_dataloader(self):
            train_data = JSON_Dataset('{}_train.json'.format(self.args.dataset))
            self.print('Training set size: {}'.format(len(train_data)))
            train_loader = DataLoader(train_data, batch_size=self.args.batch_size, shuffle=True,
                                      num_workers=2, collate_fn=self.collate)
            return train_loader
    
        def val_dataloader(self):
            valid_data = JSON_Dataset('{}_dev.json'.format(self.args.dataset))
            self.print('Validation set size: {}'.format(len(valid_data)))
            valid_loader = DataLoader(valid_data, batch_size=self.args.batch_size, num_workers=2, collate_fn=self.collate)
            return valid_loader
    
        def test_dataloader(self):
            test_data = JSON_Dataset('{}_test.json'.format(self.args.dataset))
            self.print('Test set size: {}'.format(len(test_data)))
            test_loader = DataLoader(test_data, batch_size=self.args.batch_size, num_workers=2, collate_fn=self.collate)
            return test_loader
    
        def evaluation(self, y_true, y_pred):  # multi-class evaluation metrics
            y_true = y_true.to('cpu').tolist()
            y_pred = y_pred.to('cpu').tolist()
            f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
            p = precision_score(y_true, y_pred, average='macro', zero_division=0)
            r = recall_score(y_true, y_pred, average='macro', zero_division=0)
            acc = accuracy_score(y_true, y_pred)
            return torch.tensor(f1, device=self.device), torch.tensor(p, device=self.device), \
                   torch.tensor(r, device=self.device), torch.tensor(acc, device=self.device)
    
        def configure_callbacks(self):
            checkpoint_callback = ModelCheckpoint(monitor='val_acc',
                                                  dirpath='{}/{}'.format(self.args.log_dir, self.args.dataset),
                                                  filename='best',  # filename='best-{epoch:02d}-{val_acc:.3f}',
                                                  save_top_k=1,
                                                  mode='max',
                                                  save_last=False)
    
            early_stop_callback = EarlyStopping(monitor="val_acc", min_delta=0.00, patience=10, verbose=False, mode="max")
            return [checkpoint_callback, early_stop_callback]
    
        def on_train_start(self):
            self.print("Training is started!")
    
        def on_train_end(self):
            self.print("Training is done.")
    

    Main function

    - Official advice: use pl.Trainer(gpus=1) for testing. I compared this against calling trainer.test() on the original trainer, and the results were about the same.

    def main(args):
        # unpack the Namespace into a dict if needed
        # dict_args = vars(args)
        # print(dict_args)
    
        warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
        pl.seed_everything(args.seed)
        dataset_length_and_class = {
            "AGNews": {'max_length': 100, 'class_num': 4},
            "SST1": {'max_length': 50, 'class_num': 5},
            "SST2": {'max_length': 50, 'class_num': 2},
            "Subj": {'max_length': 50, 'class_num': 2},
            # "Trec": {'max_length': 50, 'class_num': 6},
            "IMDB": {'max_length': 256, 'class_num': 2},
            "Yelp": {'max_length': 150, 'class_num': 2},
            "RT": {'max_length': 100, 'class_num': 2},
            "COLA": {'max_length': 100, 'class_num': 2},
            # "STSB": {'max_length': 100, 'class_num': 5},
            "QNLI": {'max_length': 100, 'class_num': 2},
            "QQP": {'max_length': 100, 'class_num': 2},
            # "RTE": {'max_length': 100, 'class_num': 2},
            # "WNLI": {'max_length': 100, 'class_num': 2}
        }
    
        if args.predict:
            # just load from the saved checkpoint path and everything works
            test_model = myNet.load_from_checkpoint(checkpoint_path='{}/{}/best.ckpt'.format(args.log_dir, args.dataset))
            trainer = pl.Trainer(gpus=1, precision=16)
            trainer.test(model=test_model)
        else:
            args.model_path = r'/home/LAB/magh/pytorch_study/bert-base-uncased' if args.model_name == 'bert' \
                else r'/home/LAB/magh/pytorch_study/Roberta-large'
    
            max_length = dataset_length_and_class[args.dataset]['max_length']
            class_num = dataset_length_and_class[args.dataset]['class_num']
    
            model = myNet(args=args, class_num=class_num, max_length=max_length)
            tb_logger = pl_loggers.TensorBoardLogger(save_dir="{}/{}".format(args.log_dir, args.dataset),
                                                     )
    
            # real_batch = batch_size * num_nodes * num_gpus
            # world_size = num_gpus * num_nodes
            trainer = Trainer(
                strategy="ddp", devices="auto", accelerator="auto",
                logger=tb_logger, log_every_n_steps=1,
                # if interrupted, reload from here to resume training
                flush_logs_every_n_steps=10,
                max_epochs=args.epochs)
            trainer.fit(model)
            trainer.test()
    

    Argument parsing

    if __name__ == "__main__":
        parser = ArgumentParser()
    
        parser.add_argument('--batch_size', default=128, type=int)
        parser.add_argument('--num_workers', default=4, type=int)
        parser.add_argument('--seed', default=2022, type=int)
        parser.add_argument('--lr', default=5e-5, type=float)
        parser.add_argument('--t', default=0.5, type=float)
        parser.add_argument('--epochs', default=5, type=int)
        parser.add_argument('--load_best', action='store_true')
        parser.add_argument('--weight_decay', default=5e-4, type=float)
        parser.add_argument('--warmup_steps', default=5, type=int)
    
        parser.add_argument('--dataset', default='SST1', choices=['IMDB', 'SST1', 'SST2', 'Yelp',
                                                                  'AGNews', 'Trec', 'Subj', 'RT',
                                                                  'COLA', 'STSB', 'QNLI',
                                                                  'QQP', 'RTE', 'WNLI'], type=str)
        parser.add_argument('--data_dir', default='.', type=str)
        parser.add_argument('--model_name', default='bert', type=str)
        parser.add_argument('--log_dir', default='logs', type=str)
        # run in predict/test-only mode
        parser.add_argument('--predict', action='store_true')
        # parser = Trainer.add_argparse_args(parser)
        args = parser.parse_args()
        # tensorboard --logdir=lightning_logs/
        main(args)
    
    
    

    SLURM submission: you only need to request multiple GPUs

    #!/bin/bash
    # SLURM SUBMIT SCRIPT
    ###SBATCH --mail-type=ALL
    ###SBATCH --mail-user=XXX@XXX
    #SBATCH --job-name=myjob
    ######### request two GPUs ####################
    #SBATCH --gres=gpu:2
    
    ######### activate the conda environment ####################
    source /home/LAB/anaconda3/etc/profile.d/conda.sh
    conda activate torchtext
    ######### cd to the project and run ####################
    # shellcheck disable=SC2164
    cd <path to your project>
    python train.py --dataset Subj --epochs 50
    

    Errors encountered

    • When creating a LightningDataModule, the Dataset must be instantiated inside setup() rather than in __init__; otherwise loading on the GPU raises errors (see the first sketch below).
    • When load_from_checkpoint complains about missing attributes, it is because self.save_hyperparameters() was never called; in that case you can pass the constructor arguments in manually (see the second sketch below).
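
    A minimal sketch of the first point, assuming the JSON_Dataset class and the
    file naming scheme from the code above; the MyDataModule name itself is
    illustrative:

    class MyDataModule(pl.LightningDataModule):
        def __init__(self, dataset_name, batch_size):
            super().__init__()
            self.dataset_name = dataset_name
            self.batch_size = batch_size

        def setup(self, stage=None):
            # build the Datasets here, not in __init__, so that each DDP process
            # constructs its own copy after the worker processes are spawned
            if stage in (None, 'fit'):
                self.train_data = JSON_Dataset('{}_train.json'.format(self.dataset_name))
                self.valid_data = JSON_Dataset('{}_dev.json'.format(self.dataset_name))
            if stage in (None, 'test'):
                self.test_data = JSON_Dataset('{}_test.json'.format(self.dataset_name))

        def train_dataloader(self):
            return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.valid_data, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.test_data, batch_size=self.batch_size)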
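
    And a sketch of the second point; the checkpoint path follows the layout used
    in main() above, and the argument values are whatever your __init__ expects:

    # without save_hyperparameters(), load_from_checkpoint cannot restore the
    # constructor arguments itself, but it forwards extra kwargs to __init__
    model = myNet.load_from_checkpoint(
        checkpoint_path='{}/{}/best.ckpt'.format(args.log_dir, args.dataset),
        args=args, class_num=class_num, max_length=max_length,
    )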
