美文网首页
huggingface transfromers基于预训练模型进

huggingface transfromers基于预训练模型进

作者: onmeiei | 来源:发表于2023-08-23 18:51 被阅读0次

    以Resnet为例,找到了一个微软的resnet-50的预训练模型

    from transformers import AutoImageProcessor, ResNetForImageClassification
    
    # 加载前处理处理器,自动构建;功能:输入图片,输出Tenser(1, 3, 224, 224)
    processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
    # 加载与训练模型
    model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
    
    from datasets import load_dataset
    import numpy as np
    
    # 目录结构加载数据集
    """
    目录结构如下:
    flowers102_dir
    ├── train
    │   ├── 0
    │   │   ├── image_06736.jpg
    │   │   └── image_06771.jpg
    │   ├── 1
    │   │   ├── image_05091.jpg
    │   │   └── image_05146.jpg
    │   └── 99
    │       ├── image_07900.jpg
    │       └── image_07941.jpg
    └── validation
        ├── 0
        │   ├── image_06738.jpg
        │   └── image_06773.jpg
        ├── 1
        │   ├── image_05100.jpg
        │   └── image_05138.jpg
        └── 99
            ├── image_07902.jpg
            └── image_07935.jpg
    """
    dataset = load_dataset("imagefolder", data_dir="/test/flowers102_dir")
    
    # 定义转换函数,这个地方遇到了坑。
    # 重要:前处理的输出为(1, 3, 224, 224),但是此处的pixel_values需要的是(3, 224, 224),所以需要reshape或者降维
    def trans(row):
        row["pixel_values"] = np.asarray(processor(row["image"])["pixel_values"][0])
        return row
    
    train_dataset = dataset['train'].map(trans)
    val_dataset = dataset['validation'].map(trans)
    
    from transformers import TrainingArguments, Trainer
    
    # 可以自行根据需要设置训练参数
    training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
    
    # 构建Trainer,参数还有很多,此处为最基本的。
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    
    # 开始训练
    trainer.train()
    

    相关文章

      网友评论

          本文标题:huggingface transfromers基于预训练模型进

          本文链接:https://www.haomeiwen.com/subject/gqnpmdtx.html