PyTorch Classification Notes

Author: guanxidazhuang | Published 2018-09-11 01:13

Train loss decreasing, test loss decreasing: the network is still learning.

Train loss decreasing, test loss flat: the network is overfitting.

Train loss flat, test loss decreasing: something is almost certainly wrong with the dataset.

Train loss flat, test loss flat: learning has hit a bottleneck; reduce the learning rate or the batch size.

Train loss rising, test loss rising: the network architecture is poorly designed, the training hyperparameters are badly set, or the dataset needs cleaning.
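The loss pairs this checklist is read from come from tracking both losses every epoch. A minimal sketch of such a loop (all names here are hypothetical, not code from this post):

```python
import torch

def epoch_losses(model, train_loader, test_loader, criterion, optimizer, device="cpu"):
    # One training epoch: accumulate the average train loss over batches.
    model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Evaluation pass: eval mode, no gradient tracking.
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            test_loss += criterion(model(inputs), labels).item()
    return train_loss / len(train_loader), test_loss / len(test_loader)
```

Plotting the two returned values per epoch gives the curves the diagnostics above describe.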

Today I ran the program. The dataset uses four modulation types (BPSK, QPSK, 8QAM, 16QAM), with 190 images per modulation type in the test set and 10 per type in the validation set.

First training run

First, ResNet18. The test results:

Accuracy of  BPSK : 20 %

Accuracy of  QPSK : 10 %

Accuracy of  8QAM : 70 %

Accuracy of 16QAM : 20 %
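Per-class percentages in this format can be produced with a counting loop of roughly this shape, a sketch in the style of the standard PyTorch classifier tutorial (the model, loader, and class tuple are assumptions, not the author's code):

```python
import torch

def per_class_accuracy(model, loader, classes, device="cpu"):
    # Count correct predictions separately for each class name.
    correct = {c: 0 for c in classes}
    total = {c: 0 for c in classes}
    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images.to(device)).argmax(dim=1).cpu()
            for label, pred in zip(labels, preds):
                name = classes[int(label)]
                total[name] += 1
                correct[name] += int(pred == label)
    return {c: 100.0 * correct[c] / max(total[c], 1) for c in classes}

# e.g.:
# for c, acc in per_class_accuracy(net, val_loader,
#         ('BPSK', 'QPSK', '8QAM', '16QAM')).items():
#     print('Accuracy of %5s : %2d %%' % (c, acc))
```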

Analysis: one possible cause is that the dataset contains images taken at zero frequency, and the noise near zero frequency may have hurt the results.

Second training run

I trained a VGG11 model, saved as VGG11_Cyclic_BPSK_QPSK_8QAM_16QAM.pkl. The results:

Accuracy of  BPSK : 40 %

Accuracy of  QPSK : 40 %

Accuracy of  8QAM : 30 %

Accuracy of 16QAM : 30 %

The average accuracy is 35%.

This result can still be improved. Possible ways to improve it:

1. Increase the number of iterations

2. Change the learning rate

3. Increase the depth of the VGGNet

Starting with the first option: with 10 iterations, the training loss ends at 0.013. The resulting model is VGG11_Cyclic_BPSK_QPSK_8QAM_16QAM.pkl, and the results are:

Accuracy of  BPSK : 50 %

Accuracy of  QPSK : 50 %

Accuracy of  8QAM : 10 %

Accuracy of 16QAM :  0 %

The average accuracy is 27.5%, which is lower than before.

With the batch size set to 1, the accuracies become:

Accuracy of  BPSK : 20 %

Accuracy of  QPSK : 40 %

Accuracy of  8QAM :  0 %

Accuracy of 16QAM : 10 %

So perhaps the learning rate can be lowered as training progresses.

Different loss values need different learning rates, because once the loss is already small, a large learning rate makes each update step overshoot the optimum.
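A minimal sketch of that idea (the thresholds and rates here are illustrative, not the ones used in this post): pick a learning rate from the current running loss and write it into the optimizer's parameter groups.

```python
def adjust_learning_rate(optimizer, running_loss,
                         schedule=((1.0, 1e-5), (0.5, 1e-6))):
    # schedule: (loss_threshold, lr) pairs with thresholds descending.
    # Use the rate of the first threshold the loss still exceeds;
    # below every threshold, fall back to the smallest rate.
    lr = schedule[-1][1]
    for threshold, rate in schedule:
        if running_loss > threshold:
            lr = rate
            break
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
```

Calling this once per logging interval, with the averaged loss of that interval, reproduces the behavior seen in the logs below.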

While training, I then noticed that the loss keeps decreasing, all the way below 0.01, until the 500th sample, at which point it jumps back above 2. Since each class has exactly 500 samples, this shows that if the data is not shuffled, the classifier's loss suffers. The model is saved as VGG11_Cyclic_BPSK_QPSK_8QAM_16QAM_3.pkl.

[10,  950] loss: 0.048
('Learning rate is', 1e-06)
[10, 1000] loss: 1.798
('Learning rate is', 1e-05)
[10, 1050] loss: 2.041
('Learning rate is', 1e-05)
[10, 1100] loss: 1.369
('Learning rate is', 1e-05)
[10, 1150] loss: 1.073
('Learning rate is', 1e-05)
[10, 1200] loss: 0.836
('Learning rate is', 1e-05)
[10, 1250] loss: 0.572
('Learning rate is', 1e-06)
[10, 1300] loss: 0.315
('Learning rate is', 1e-06)
[10, 1350] loss: 0.154
('Learning rate is', 1e-06)
[10, 1400] loss: 0.086
('Learning rate is', 1e-06)
[10, 1450] loss: 0.056
('Learning rate is', 1e-06)
[10, 1500] loss: 2.172
('Learning rate is', 1e-05)
[10, 1550] loss: 1.758
('Learning rate is', 1e-05)
[10, 1600] loss: 1.339
('Learning rate is', 1e-05)
[10, 1650] loss: 1.118
('Learning rate is', 1e-05)
[10, 1700] loss: 0.873
('Learning rate is', 1e-05)
[10, 1750] loss: 0.629
('Learning rate is', 1e-06)
[10, 1800] loss: 0.362
('Learning rate is', 1e-06)
[10, 1850] loss: 0.179
('Learning rate is', 1e-06)
[10, 1900] loss: 0.100
('Learning rate is', 1e-06)
[10, 1950] loss: 0.058
('Learning rate is', 1e-06)

So next I shuffle the data, by setting shuffle=True in the DataLoader:

trainloader = torch.utils.data.DataLoader(train_loader, batch_size=define_batch_size, shuffle=True)

What was happening without shuffling is that images of one modulation type had just been fitted, or overfitted, when a run of signals from another class arrived and destroyed what had been learned.

(Figure: gradient descent; the red arrow illustrates a gradient step that overshoots the optimum.)

Third training run

After setting shuffle=True, the training loss stays roughly constant:

[1,   50] loss: 1.392
[1,  100] loss: 1.412
[1,  150] loss: 1.412
[1,  200] loss: 1.409
[1,  250] loss: 1.391
[1,  300] loss: 1.413
[1,  350] loss: 1.395
[1,  400] loss: 1.424
[1,  450] loss: 1.406
[1,  500] loss: 1.371
[1,  550] loss: 1.391
[1,  600] loss: 1.411
[1,  650] loss: 1.409
[1,  700] loss: 1.377
[1,  750] loss: 1.396
[1,  800] loss: 1.403
[1,  850] loss: 1.402
[1,  900] loss: 1.387
[1,  950] loss: 1.389
[1, 1000] loss: 1.394
[1, 1050] loss: 1.385
[1, 1100] loss: 1.349
[1, 1150] loss: 1.384
[1, 1200] loss: 1.393
[1, 1250] loss: 1.394
[1, 1300] loss: 1.385
[1, 1350] loss: 1.386
[1, 1400] loss: 1.388
[1, 1450] loss: 1.387
[1, 1500] loss: 1.375
[1, 1550] loss: 1.392
[1, 1600] loss: 1.411
[1, 1650] loss: 1.389
[1, 1700] loss: 1.388
[1, 1750] loss: 1.385
[1, 1800] loss: 1.400
[1, 1850] loss: 1.390
[1, 1900] loss: 1.408
[1, 1950] loss: 1.375
(learning rate 1e-05 throughout)

It seems the learning rate should be made smaller.

This time the learning rate was adjusted as follows: when the loss is above 0.8, the learning rate is 1e-7; when the loss is between 0.02 and 0.8, it is 1e-8. But this time the loss got stuck between 0.9 and 1 and would not drop further, which suggests that once the loss is below 1, the learning rate still needs to come down more.

The model from this run is saved as VGG11_Cyclic_BPSK_QPSK_8QAM_16QAM_4.pkl, with results:

Accuracy of  BPSK : 40 %

Accuracy of  QPSK : 30 %

Accuracy of  8QAM : 10 %

Accuracy of 16QAM : 40 %

Fourth training run

This time I use ResNet50, with the settings:

loss > 1: learning rate = 1e-7

0.7 < loss < 1: learning rate = 1e-9

0.1 < loss < 0.7: learning rate = 1e-12

loss < 0.1: learning rate = 1e-14

This time the loss hovers around 0.5, better than VGGNet:

[10, 1450] loss: 0.481
[10, 1500] loss: 0.541
[10, 1550] loss: 0.501
[10, 1600] loss: 0.576
[10, 1650] loss: 0.543
[10, 1700] loss: 0.473
[10, 1750] loss: 0.562
[10, 1800] loss: 0.579
[10, 1850] loss: 0.504
[10, 1900] loss: 0.490
[10, 1950] loss: 0.451
(learning rate 1e-12 throughout)

Since the loss hovers around 0.5, another threshold is needed there: once the loss drops below 0.5, the learning rate should decrease again. Also, the 0.7 mark is a bit low and can be raised to 0.8. The model from this run was saved, and the settings are now:

######################################################################
# Define the learning rate
change_learning_rate_mark1 = 2.2
learning_rate1 = 1e-4
change_learning_rate_mark2 = 2.0
learning_rate2 = 9e-5
change_learning_rate_mark3 = 1.8
learning_rate3 = 8e-5
change_learning_rate_mark4 = 1.6
learning_rate4 = 7e-5
change_learning_rate_mark5 = 1.4
learning_rate5 = 6e-5
change_learning_rate_mark6 = 1.2
learning_rate6 = 5e-5
change_learning_rate_mark7 = 1
learning_rate7 = 4e-5
change_learning_rate_mark8 = 0.8
learning_rate8 = 3e-5
change_learning_rate_mark9 = 0.6
learning_rate9 = 2e-5
change_learning_rate_mark10 = 0.4
learning_rate10 = 1e-5
change_learning_rate_mark11 = 0.2
learning_rate11 = 9e-6
change_learning_rate_mark12 = 0.08
learning_rate12 = 8e-6
learning_rate13 = 7e-6
stop_loss_function = 0.001
# ==========================================================
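One plausible reading of these marks (an assumption about how the training loop uses them, consistent with the logged loss/rate pairs that follow): the loss is compared against the marks from the top down, the first mark it exceeds selects the rate, and learning_rate13 is the floor.

```python
# Threshold/rate pairs mirroring the change_learning_rate_mark*/
# learning_rate* settings above, in descending order of threshold.
SCHEDULE = [(2.2, 1e-4), (2.0, 9e-5), (1.8, 8e-5), (1.6, 7e-5),
            (1.4, 6e-5), (1.2, 5e-5), (1.0, 4e-5), (0.8, 3e-5),
            (0.6, 2e-5), (0.4, 1e-5), (0.2, 9e-6), (0.08, 8e-6)]
FLOOR_LR = 7e-6  # learning_rate13: used once the loss is below every mark

def lr_for_loss(loss):
    # Return the rate of the first (largest) mark the loss exceeds.
    for mark, rate in SCHEDULE:
        if loss > mark:
            return rate
    return FLOOR_LR
```

This reading matches the logs below: a loss of 0.451 exceeds the 0.4 mark and selects 1e-5, while 0.374 falls to the 0.2 mark and selects 9e-6.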

In the end the loss comes down to around 0.3:

('Learning rate is', 9e-06)
[10, 1600] loss: 0.374
('Learning rate is', 9e-06)
[10, 1650] loss: 0.451
('Learning rate is', 1e-05)
[10, 1700] loss: 0.478
('Learning rate is', 1e-05)
[10, 1750] loss: 0.417
('Learning rate is', 1e-05)
[10, 1800] loss: 0.344
('Learning rate is', 9e-06)
[10, 1850] loss: 0.301
('Learning rate is', 9e-06)
[10, 1900] loss: 0.354
('Learning rate is', 9e-06)
[10, 1950] loss: 0.317
('Learning rate is', 9e-06)
Finished Training

The results:

Accuracy of  BPSK : 47 %

Accuracy of  QPSK : 38 %

Accuracy of  8QAM : 42 %

Accuracy of 16QAM : 52 %

Fifth training run

Next I try ResNet152, this time with the learning rate decaying as x^{-2}:
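Reading the rule as "the n-th rate is base * n^{-2}" fits many of the values below (for example learning_rate4 = 0.0625e-5 = 1e-5 * 4^{-2} and learning_rate10 = 1e-7 = 1e-5 * 10^{-2}), though the first few entries deviate, so treat the rule as approximate. A generator for such a schedule might look like:

```python
def power_decay_rates(base_lr, n_steps, power=2):
    # n-th learning rate = base_lr * n**(-power), for n = 1..n_steps
    return [base_lr * n ** (-power) for n in range(1, n_steps + 1)]
```

With power=4 the same helper produces the x^{-4} schedule tried later in this run.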

######################################################################
# Define the learning rate
change_learning_rate_mark1 = 2.2
learning_rate1 = 1e-3
change_learning_rate_mark2 = 2.0
learning_rate2 = 5e-4
change_learning_rate_mark3 = 1.8
learning_rate3 = 1.1e-4
change_learning_rate_mark4 = 1.6
learning_rate4 = 0.0625e-5
change_learning_rate_mark5 = 1.4
learning_rate5 = 0.04e-5
change_learning_rate_mark6 = 1.2
learning_rate6 = 0.0278e-5
change_learning_rate_mark7 = 1
learning_rate7 = 0.0204e-5
change_learning_rate_mark8 = 0.8
learning_rate8 = 0.0156e-5
change_learning_rate_mark9 = 0.6
learning_rate9 = 0.0123e-5
change_learning_rate_mark10 = 0.4
learning_rate10 = 1e-7
change_learning_rate_mark11 = 0.2
learning_rate11 = 0.0083e-6
change_learning_rate_mark12 = 0.08
learning_rate12 = 0.0069e-6
learning_rate13 = 0.0059e-6
stop_loss_function = 0.001
# ==========================================================

The loss again got stuck around 0.8, so the learning rate decay was changed to x^{-4}. The results:

[10, 1450] loss: 1.561
('Learning rate is', 8.4999e-08)
[10, 1500] loss: 1.543
('Learning rate is', 8.4999e-08)
[10, 1550] loss: 1.589
('Learning rate is', 8.4999e-08)
[10, 1600] loss: 1.472
('Learning rate is', 3.8147e-08)
[10, 1650] loss: 1.574
('Learning rate is', 8.4999e-08)
[10, 1700] loss: 1.755
('Learning rate is', 6.4e-07)
[10, 1750] loss: 1.410
('Learning rate is', 3.8147e-08)
[10, 1800] loss: 1.533
('Learning rate is', 8.4999e-08)
[10, 1850] loss: 1.497
('Learning rate is', 3.8147e-08)
[10, 1900] loss: 1.544
('Learning rate is', 8.4999e-08)
[10, 1950] loss: 1.472
('Learning rate is', 3.8147e-08)
Finished Training

The model is saved as ResNet50_Cyclic_BPSK_QPSK_8QAM_16QAM3.pkl.

Sixth training run

This time ResNet152 again, configured as:

######################################################################
# Define the learning rate
change_learning_rate_mark1 = 2
learning_rate1 = 1e-2
change_learning_rate_mark2 = 1.5
learning_rate2 = 1.5625e-3
change_learning_rate_mark3 = 1.0
learning_rate3 = 1.3717e-4
change_learning_rate_mark4 = 0.5
learning_rate4 = 2.4414e-5
change_learning_rate_mark5 = 0.1
learning_rate5 = 6.4e-6
change_learning_rate_mark6 = 0.05
learning_rate6 = 2.1433e-6
change_learning_rate_mark7 = 0.01
learning_rate7 = 8.4999e-7
change_learning_rate_mark8 = 0.005
learning_rate8 = 3.8147e-7
change_learning_rate_mark9 = 0.001
learning_rate9 = 1.8817e-7
change_learning_rate_mark10 = 0.0005
learning_rate10 = 1e-7
change_learning_rate_mark11 = 0.0001
learning_rate11 = 5.6447e-8
change_learning_rate_mark12 = 0.00005
learning_rate12 = 3.34909e-8
learning_rate13 = 2.0718e-8
stop_loss_function = 0.001
# ==========================================================


Source: https://www.haomeiwen.com/subject/kvzsgftx.html