3.4 New model
3.4.1 Save model
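PyTorch provides two ways to save a model: saving the entire model object, or saving only its parameters (the state_dict). We recommend saving only the parameters, because this is more flexible and does not tie the saved file to the exact class definition and code layout used when the model was saved.
A model's state_dict holds its learnable parameters together with registered buffers such as the batch-norm running statistics. The optimizer has its own state_dict, which stores its hyperparameters and internal state (for example momentum buffers) rather than learnable parameters. Saving it as well is useful when you want to resume training.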
A common PyTorch convention is to save models using either a .pt or .pth file extension.
Read more about saving and loading models from this link.
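To resume training later, the model and optimizer states are often saved together in one checkpoint dictionary. The sketch below shows this pattern; the small nn.Sequential model, the dictionary keys and the in-memory buffer are illustrative assumptions, and the notebook's FeedForwardNeuralNetwork would be handled the same way.

```python
import io

import torch
import torch.nn as nn

# A small stand-in model; FeedForwardNeuralNetwork would be used the same way.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step, so the optimizer has momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Bundle model and optimizer state into a single checkpoint dictionary.
checkpoint = {
    "epoch": 1,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
buffer = io.BytesIO()  # in-memory file; a path such as './checkpoint.pt' works the same way
torch.save(checkpoint, buffer)

# Later: rebuild the same architecture, then restore both states.
buffer.seek(0)
loaded = torch.load(buffer)
new_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_model.load_state_dict(loaded["model_state_dict"])
new_optimizer.load_state_dict(loaded["optimizer_state_dict"])
start_epoch = loaded["epoch"] + 1
```

Only the model state is needed for inference; the optimizer state matters when training continues, because it restores the momentum buffers along with the hyperparameters.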
# Print the model's state_dict: the name and shape of each parameter and buffer
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Print the optimizer's state_dict: its internal state and hyperparameters
print("\nOptimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
Model's state_dict:
hidden1.weight torch.Size([100, 784])
hidden1.bias torch.Size([100])
hidden2.weight torch.Size([100, 100])
hidden2.bias torch.Size([100])
hidden3.weight torch.Size([100, 100])
hidden3.bias torch.Size([100])
classification_layer.weight torch.Size([10, 100])
classification_layer.bias torch.Size([10])
hidden1_bn.weight torch.Size([100])
hidden1_bn.bias torch.Size([100])
hidden1_bn.running_mean torch.Size([100])
hidden1_bn.running_var torch.Size([100])
hidden1_bn.num_batches_tracked torch.Size([])
hidden2_bn.weight torch.Size([100])
hidden2_bn.bias torch.Size([100])
hidden2_bn.running_mean torch.Size([100])
hidden2_bn.running_var torch.Size([100])
hidden2_bn.num_batches_tracked torch.Size([])
hidden3_bn.weight torch.Size([100])
hidden3_bn.bias torch.Size([100])
hidden3_bn.running_mean torch.Size([100])
hidden3_bn.running_var torch.Size([100])
hidden3_bn.num_batches_tracked torch.Size([])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.75, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4755824576, 4755820904, 4750998264, 4757925536, 4757922584, 4758702408, 4758703200, 4758702552, 4758702480, 4758702264, 4758703704, 4758702912, 4764186232, 4764188032]}]
# save model
save_path = './model.pt'
torch.save(model.state_dict(), save_path)
# load the saved parameters from the file
saved_parametes = torch.load(save_path)
print(saved_parametes)
OrderedDict([('hidden1.weight', tensor([[ 0.0061, 0.0296, -0.0111, ..., 0.0030, -0.0219, -0.0101],
[-0.0171, 0.0213, 0.0470, ..., 0.0168, -0.0097, -0.0076],
[-0.0094, 0.0342, 0.0366, ..., 0.0347, 0.0201, -0.0014],
...,
[ 0.0357, 0.0599, 0.0044, ..., 0.0245, 0.0249, 0.0117],
[ 0.0388, -0.0259, 0.0334, ..., 0.0303, 0.0065, -0.0191],
[ 0.0564, 0.0475, 0.0173, ..., 0.0403, 0.0442, 0.0449]])), ('hidden1.bias', tensor([-0.0168, -0.0027, -0.0294, -0.0164, 0.0031, -0.1126, -0.1200, -0.0309,
0.0018, -0.0125, -0.0191, -0.0128, -0.0523, -0.0306, 0.0244, -0.0634,
-0.0119, -0.0476, -0.1635, -0.0615, 0.0005, -0.0329, -0.0547, -0.0155,
-0.0197, -0.0935, -0.0182, -0.1492, 0.0312, -0.0513, -0.1478, -0.0836,
0.0351, -0.0060, 0.0264, 0.0090, -0.0292, -0.0760, -0.0030, -0.0301,
-0.0226, -0.1158, -0.0211, -0.0105, -0.1547, -0.1294, -0.0352, -0.0362,
-0.0490, -0.0284, -0.0899, -0.0111, 0.0088, 0.0089, -0.1379, -0.0392,
0.0047, -0.0556, -0.1105, -0.0871, -0.0625, -0.0557, -0.0433, -0.0270,
-0.0180, 0.0207, -0.0378, -0.0158, -0.1503, -0.0545, -0.0462, -0.0816,
0.0008, -0.0367, -0.0082, -0.0644, -0.0191, -0.0992, -0.0545, -0.0881,
-0.1154, -0.0954, -0.0931, -0.0208, -0.1681, -0.0307, 0.0138, -0.0588,
-0.0424, -0.0218, -0.0310, -0.0141, -0.0217, -0.0678, -0.1139, 0.0142,
-0.0263, -0.0896, -0.0440, -0.0806])), ('hidden2.weight', tensor([[ 0.0156, 0.0259, 0.0132, ..., 0.0317, 0.0130, -0.0083],
[-0.0703, 0.0066, 0.0261, ..., -0.1618, -0.1010, -0.0783],
[-0.0013, 0.0448, -0.0532, ..., -0.0807, 0.0350, 0.0551],
...,
[-0.0748, -0.0055, -0.0958, ..., -0.0372, 0.0271, -0.1036],
[ 0.0920, 0.1272, 0.0763, ..., -0.0787, 0.0597, -0.1064],
[-0.0779, 0.0371, 0.0344, ..., -0.0633, 0.0402, -0.0065]])), ('hidden2.bias', tensor([-0.5761, 0.5198, 0.3693, -0.1639, -0.1722, -0.4134, 1.0224, 0.0591,
-0.1358, 0.0150, -0.1590, -0.2059, -0.0574, 0.3346, -0.1240, -0.0494,
-0.0782, -0.0758, 0.2674, -0.0309, -0.2096, -0.3061, -0.1266, -0.2250,
-0.0352, -0.3626, -0.3968, -0.1523, -0.1501, 0.0105, -0.1572, 0.4409,
-0.0585, -0.1668, 0.0431, -0.3306, -0.2386, -0.4994, -0.0402, 0.2434,
-0.0695, 0.4839, -0.0635, -0.3354, -0.2052, 0.1460, -0.3221, -0.4942,
-0.4669, -0.1758, -0.2361, 0.0703, -0.0994, -0.3179, -0.0522, -0.3119,
0.4844, 1.0562, -0.2837, -0.2965, -0.1459, -0.1997, -0.5648, -0.0028,
-0.2376, -0.1025, -0.0931, -0.1769, 0.0466, -0.0933, -0.1596, -0.3318,
-0.2438, 0.0077, -0.1148, -0.0701, -0.2182, 0.0352, -0.1677, -0.2224,
-0.1809, 0.0568, -0.0896, -0.0801, -0.2565, -0.4778, -0.1549, -0.0518,
-0.5629, -0.0945, 0.8213, -0.0217, -0.0893, -0.3187, -0.2347, 0.4022,
-0.3037, -0.0043, -0.0388, 0.0045])), ('hidden3.weight', tensor([[-0.1371, -0.1332, 0.0756, ..., -0.1936, -0.1040, -0.0236],
[ 0.0034, 0.0138, -0.0925, ..., -0.0231, -0.1404, -0.0059],
[-0.0852, 0.0128, -0.0367, ..., 0.2121, -0.1505, -0.0288],
...,
[-0.1071, 0.0453, -0.0177, ..., -0.0548, -0.0398, 0.1109],
[-0.0492, 0.0867, 0.3073, ..., -0.0626, 0.1075, 0.2109],
[-0.1140, -0.0369, -0.0115, ..., -0.0396, -0.0358, -0.0073]])), ('hidden3.bias', tensor([-8.4660e-02, -3.2333e-02, -6.0429e-02, -1.2267e-01, -1.1553e-01,
-3.6592e-02, -9.9289e-02, 6.3957e-01, -2.0471e-01, -1.2567e-01,
-2.4764e-02, -1.0635e-01, -2.6803e-02, -8.6840e-02, -2.4284e-01,
-1.1553e-01, 1.1392e-03, -1.0988e-01, -9.8350e-02, -2.0178e-02,
-1.0630e-01, -8.7644e-02, -6.7755e-02, -1.5455e-01, -8.0500e-02,
2.2053e-01, 5.6742e-02, -1.4824e-01, -3.3071e-02, 3.2688e-02,
6.2942e-01, -4.6284e-02, 2.1287e-01, -3.4355e-02, -1.2961e-01,
2.9527e-01, 1.2094e-03, -3.3945e-02, -2.1949e-01, -7.0505e-02,
-9.2214e-02, -1.1195e-01, 3.7178e-01, -2.5034e-02, -1.8616e-01,
-1.0701e-01, -6.5656e-02, -6.3755e-02, 8.5521e-01, -1.4393e-01,
-1.8443e-01, 1.7599e-02, 4.3720e-01, -1.0936e-01, 1.0006e-01,
-8.8871e-02, 5.2978e-01, -1.1293e-01, -1.1250e-01, -2.5872e-01,
-2.4333e-01, -7.4563e-02, -1.1477e-01, 9.9877e-02, -1.2331e-01,
-1.0594e-01, -2.8752e-02, -3.4128e-02, -2.5374e-01, -8.5538e-02,
-8.6164e-02, 4.9599e-01, 6.3113e-01, -4.9306e-02, 8.3178e-02,
9.6917e-01, 1.7951e+00, -1.9829e-01, -1.7462e-01, 1.0686e-01,
2.3232e-02, -1.1916e-01, -1.2637e-01, 1.2163e+00, -5.2430e-02,
-1.2705e-01, -1.1642e-01, -1.4296e-01, -7.0017e-02, 3.6222e-01,
-1.9231e-01, -9.3500e-02, -6.6554e-02, 7.4068e-02, -1.1235e-01,
-1.0035e-01, 2.5663e-01, -6.0805e-02, 1.2717e+00, -8.1130e-02])), ('classification_layer.weight', tensor([[-6.1187e-01, -5.6507e-01, -1.4424e-03, -2.7045e-01, 1.5389e-01,
-1.4199e-01, -1.3265e-01, -1.2181e-01, -5.0219e-01, -7.1377e-02,
1.0560e-02, -3.2474e-01, -1.6185e-01, -5.3878e-02, 2.9584e-01,
1.2464e-03, -1.1910e-01, -1.5456e-01, -4.6994e-01, 7.8584e-02,
-5.3734e-01, 6.5176e-01, -7.4570e-03, 7.1858e-02, -9.0464e-02,
-5.4486e-02, -4.3265e-01, -4.6849e-02, -1.6478e-01, -6.6419e-01,
-1.5395e-01, -7.8686e-02, -7.1704e-02, -3.8201e-02, 1.1336e-03,
3.0307e-01, -6.6520e-02, -4.9982e-02, -1.5092e-01, 3.2128e-02,
-3.9149e-01, 3.1262e-02, 3.2770e-02, 1.7711e-02, 1.5304e-01,
-1.3411e-01, 2.5674e-02, -1.7345e-02, 3.5925e-01, -3.7818e-01,
-2.2275e-01, -4.9380e-01, -1.3756e-01, -2.8159e-01, -1.1654e-01,
-2.2355e-02, -5.9519e-01, 1.6007e-02, 1.5933e-01, -1.2804e-01,
-2.1505e-01, -6.4397e-02, -3.2399e-01, -5.6055e-02, -5.0692e-01,
-2.1875e-01, -8.4137e-02, -1.7504e-01, -1.1924e-01, 5.5566e-02,
-3.4110e-01, 8.7355e-03, -5.7918e-03, -6.6834e-02, -1.4117e-01,
-5.4462e-01, 2.7181e-01, 6.9094e-02, -5.3700e-02, -1.1022e-01,
1.0807e-02, -1.8002e-01, -2.0719e-02, 1.1164e-01, 2.1247e-02,
-8.1494e-01, -2.8763e-01, -3.5509e-01, 4.1251e-02, 3.2906e-01,
1.0091e-01, -1.9347e-01, -1.4978e-01, 8.7678e-02, 2.4100e-02,
-1.8897e-01, -4.3147e-01, 5.3150e-03, -1.3036e-01, 2.1785e-02],
[-1.2737e-01, -6.2360e-02, -2.6765e-01, -1.5767e-01, -2.5594e-02,
-6.6481e-02, 1.7943e-01, 4.4155e-02, -2.4315e-01, -4.1117e-02,
-1.2292e-01, 2.7662e-02, -2.0391e-01, 2.3087e-01, -7.2216e-02,
6.3339e-02, -6.5326e-02, -1.2291e-01, -6.6390e-02, 1.8075e-01,
-1.3624e-01, -3.4863e-01, -4.5377e-02, -2.3763e-01, -1.3728e-01,
-1.0981e-01, 8.0206e-01, -2.9498e-02, 1.0278e-01, 1.7782e-01,
3.7626e-02, -4.1234e-02, 3.0991e-02, -1.1380e-01, 5.3500e-02,
4.3036e-02, 2.1550e-01, -1.3913e-03, 3.5874e-02, -9.0758e-02,
-6.9365e-03, -3.5689e-02, -1.4543e-01, -4.2391e-02, 1.0947e-01,
-2.7372e-02, -2.8920e-02, 1.0706e-02, -5.6517e-02, -1.0215e-01,
-1.1967e-01, -8.3464e-02, 5.2941e-01, -1.6522e-01, 1.5481e-01,
1.5158e-03, -2.7996e-01, 6.0395e-02, -3.8386e-02, -8.7508e-02,
-6.1256e-02, -6.3347e-03, -3.6381e-03, 8.2017e-02, 7.5017e-02,
-3.0227e-02, -1.6147e-01, 5.2752e-02, -1.7114e-01, -1.0651e-01,
4.2548e-02, -7.4882e-02, -1.2945e-01, 1.9034e-02, -8.8790e-02,
-1.4525e-01, -1.2573e-01, -2.6728e-02, -2.5224e-02, -3.1653e-01,
8.1042e-02, -1.5198e-01, -3.8757e-02, -1.7580e-01, 6.2542e-02,
5.6218e-02, -5.4357e-02, -1.2610e-01, 5.5038e-02, -9.7318e-02,
1.9284e-01, 1.1065e-01, -5.6410e-02, -1.7474e-02, -1.0433e-01,
8.9177e-02, -7.3771e-02, -1.5495e-01, 1.1680e-01, -5.6681e-03],
[ 1.9186e-01, 1.0223e-01, 1.0701e-01, 2.7541e-01, -7.9669e-01,
2.0961e-01, 9.7709e-02, -7.4805e-01, 1.8823e-01, 2.1411e-01,
2.1111e-02, 2.3222e-01, 1.8330e-01, 1.1841e-01, -3.1634e-02,
1.5799e-01, 4.0911e-02, 8.6514e-02, 2.1074e-01, -3.2546e-01,
-4.7492e-02, 6.6884e-02, -1.3732e-01, -4.2719e-02, 6.7551e-02,
-2.7026e-02, 1.0300e-01, -3.1153e-02, 6.7054e-02, -6.5785e-02,
-3.3956e-01, -6.5564e-02, -2.8902e-03, -5.6428e-02, -7.6258e-01,
1.1608e-01, 4.2748e-01, 6.4623e-02, 8.0076e-02, 4.0094e-02,
1.4439e-01, 1.7171e-01, -4.2383e-01, -7.9801e-02, 1.9432e-02,
-3.3688e-02, -3.8898e-02, 3.7044e-02, 2.2053e-01, 1.0799e-04,
-7.8533e-02, 2.4398e-01, -4.2730e-01, 4.2746e-02, 1.2053e-01,
-2.5344e-02, -3.3898e-01, 1.2668e-01, -1.0610e-01, -2.4014e-01,
-8.3949e-03, -6.4227e-02, 2.2441e-02, -1.1778e-01, 6.6599e-02,
3.1477e-02, -1.9016e-04, 8.3123e-02, 1.5891e-01, 1.7011e-01,
1.5146e-01, -7.9873e-01, -3.9691e-01, -1.3733e-01, 7.5391e-02,
2.0980e-01, -3.2066e-01, -1.3985e-02, 7.8271e-03, -6.8302e-03,
1.7113e-01, 1.0262e-01, -1.6727e-03, -6.6286e-02, -2.0698e-01,
1.8812e-01, -4.9481e-02, -2.0186e-01, 2.3084e-01, -1.0781e-01,
-6.5467e-01, -6.3727e-04, 1.1265e-01, -2.4518e-01, -7.0207e-03,
7.2730e-02, 1.9851e-01, -6.7493e-02, 6.7490e-01, -4.7554e-04],
[ 8.5253e-02, 3.6790e-01, -2.0190e-01, -1.2884e-02, 2.7314e-01,
1.5281e-01, 1.3067e-01, -1.1924e-01, 4.6366e-01, 5.2650e-02,
3.8249e-02, 6.9268e-02, 4.2244e-02, 5.4462e-02, 4.3517e-01,
1.4550e-01, 1.2783e-02, -3.6348e-01, -4.2320e-02, 1.1229e-01,
1.2455e-01, -3.9866e-01, 1.1623e-01, -9.5708e-02, -3.7096e-01,
-3.1879e-01, -1.1747e-01, -3.0447e-01, 1.0429e-01, 2.8921e-01,
-2.6497e-01, 2.5279e-01, 7.5485e-02, -2.3777e-01, 1.3919e-01,
4.4179e-01, 7.6398e-02, 1.2101e-01, 1.2285e-01, -4.5768e-02,
1.6343e-01, 5.6805e-02, 6.3629e-01, -9.1578e-02, 2.1487e-01,
4.9422e-02, 9.7040e-02, 1.7530e-02, 4.7804e-02, 1.1712e-01,
1.8626e-01, -6.1781e-03, -3.3298e-01, -2.0199e-01, 2.8425e-01,
4.5014e-02, -1.7341e-01, -5.8929e-04, -8.4011e-04, 2.6978e-01,
4.0128e-01, 1.3957e-01, 1.9709e-02, 1.9069e-01, 1.9123e-01,
-5.0225e-02, -4.7232e-02, -4.8383e-02, -6.6502e-02, -3.8041e-02,
-2.3345e-02, 2.3051e-01, -4.9389e-02, -4.8033e-02, 1.1691e-01,
5.1496e-01, 1.0084e-01, 7.1630e-02, 3.2055e-02, -3.3248e-01,
-2.0284e-02, -1.6052e-01, -1.9679e-01, -4.4816e-02, 4.8449e-02,
3.4831e-01, 1.8643e-01, 2.9630e-01, -1.4649e-01, -2.7486e-01,
1.6517e-01, 3.6800e-02, 2.5259e-02, -9.9867e-02, 4.6995e-02,
1.6073e-01, 1.3008e-01, -4.7025e-02, 3.1125e-01, 6.8535e-02],
[-6.0539e-02, -4.4903e-03, -5.5134e-02, 2.5655e-01, 1.9688e-01,
1.5103e-01, -1.1826e-02, 5.3932e-01, 7.4606e-01, -1.1625e-01,
2.0639e-01, 1.0444e-01, -1.2815e-01, -8.3173e-02, 7.3563e-01,
-5.1470e-02, 1.3627e-01, -1.0789e-01, 7.4358e-02, -6.6397e-02,
-1.4015e-03, -1.7172e-01, -1.2843e-02, 4.1225e-01, -1.1703e-01,
1.9197e-01, 1.6208e-01, -1.0199e-01, 1.2796e-02, -7.4289e-02,
-1.6336e-01, 1.0056e-01, -1.6830e-02, -4.2748e-02, 2.5940e-01,
-3.4320e-01, -8.1927e-02, -2.8906e-02, -2.6072e-02, 5.1455e-02,
7.6014e-02, 7.1832e-02, -6.1156e-01, 3.4221e-02, 1.4976e-01,
-1.4457e-01, -3.2255e-03, -3.5813e-02, -3.8536e-01, 2.4207e-02,
5.6061e-02, 5.4010e-02, 2.2706e-01, -2.7755e-02, -2.5191e-01,
1.4227e-01, 1.6484e-01, 4.1759e-03, 2.1995e-01, -1.0093e-01,
5.4968e-01, -3.1629e-01, 1.4322e-01, 2.0420e-02, 4.4953e-02,
-1.2201e-01, 1.6357e-02, -5.3477e-02, 3.4353e-02, 1.7106e-02,
2.1129e-02, 4.8434e-02, 3.3463e-01, -4.9149e-03, -2.9105e-01,
-4.1212e-01, -1.6414e-01, 5.6706e-02, -9.4353e-02, 4.0012e-01,
1.1213e-01, -9.4816e-03, -9.8370e-02, -2.9450e-01, -9.1417e-03,
-8.6727e-02, -9.5072e-02, 3.8655e-01, -7.0459e-02, 5.6630e-01,
-3.7482e-02, -8.3377e-02, 4.8785e-02, -2.2670e-01, -2.7037e-02,
-8.2207e-02, 4.1002e-01, 5.0883e-02, -3.9839e-01, -4.4907e-02],
[ 1.6589e-01, 9.1893e-02, 1.9530e-01, 9.2547e-02, 1.7666e-01,
-2.1956e-01, 5.9964e-02, 1.5438e-01, 1.8840e-01, -5.4930e-02,
2.8582e-02, 4.0812e-02, 1.9048e-01, 2.9546e-02, 2.1266e-01,
3.4931e-02, -2.8150e-02, -4.4153e-02, 9.9252e-02, 1.8868e-01,
1.3800e-01, 2.0872e-02, 1.5372e-01, 2.1436e-02, 6.6080e-01,
-2.0198e-01, -4.5529e-01, 3.0689e-01, -2.0198e-02, 2.4786e-01,
3.5457e-01, 2.6853e-01, -5.6232e-02, 3.6438e-01, 6.8775e-02,
1.0726e-01, 1.6393e-01, -5.9914e-03, -4.6087e-03, 2.8990e-02,
-4.6558e-02, 9.0374e-02, 6.3195e-01, -6.4279e-02, 2.0935e-02,
9.9869e-02, -4.7908e-02, -4.9447e-02, 1.0565e-01, 1.4623e-01,
-1.7076e-01, 1.9899e-01, -1.0314e-02, -3.7178e-02, -1.6649e-01,
5.4754e-02, 2.6698e-01, 8.6397e-02, -4.9514e-02, -2.0020e-01,
9.0062e-02, 4.2456e-01, 8.0474e-02, -8.2473e-02, 1.6329e-01,
1.9541e-01, -5.2071e-02, 1.3438e-01, 5.3355e-02, -4.0501e-02,
7.2137e-02, 9.4610e-02, -1.2058e-01, 1.0757e-02, 3.5982e-01,
-7.8495e-02, -1.7245e-01, -1.8146e-02, -9.3419e-02, 8.2358e-02,
1.1147e-01, -5.9163e-02, 3.2430e-01, 5.5334e-02, 1.8202e-01,
1.1893e-01, 1.5508e-02, 1.8176e-01, -1.0215e-02, -2.3243e-01,
8.3380e-02, 8.7213e-02, 6.3830e-02, 5.1061e-01, 1.6680e-01,
1.0405e-01, -1.7493e-01, 1.2890e-01, -2.6674e-01, 1.2023e-01],
[ 8.3150e-02, -3.8342e-02, 9.8947e-02, 2.5845e-02, 7.0042e-02,
-3.4700e-01, -3.1025e-02, -4.6168e-01, 5.8525e-02, 4.1622e-02,
7.0948e-03, -3.8280e-02, 1.4927e-01, -5.3696e-02, 8.4766e-02,
-9.2226e-02, -3.8204e-02, 2.7261e-01, 1.0362e-01, 5.6582e-01,
1.5963e-01, 3.7851e-01, 8.8275e-02, -9.4182e-03, -8.8659e-02,
-1.0709e-01, -2.6623e-01, -1.2749e-01, -8.3864e-02, -2.0092e-01,
2.4163e-01, -2.1115e-01, -1.4877e-01, -1.3554e-01, -6.6022e-02,
-3.8177e-01, 2.8522e-01, -1.3721e-01, -9.3008e-02, -7.0277e-03,
-4.5320e-02, 7.0163e-02, -3.9012e-01, -1.0984e-02, -2.6638e-01,
2.8035e-02, -6.4254e-02, 4.2502e-02, -1.0570e-01, -4.1955e-02,
-2.9909e-01, 1.5473e-01, -2.2444e-01, 3.3241e-01, -6.4602e-01,
2.5705e-04, 1.9962e-01, 1.7638e-02, -1.9582e-01, -3.2925e-01,
2.2283e-02, 7.8028e-03, 1.5140e-01, -6.0125e-02, -2.0571e-02,
2.1563e-01, 1.5333e-02, 9.4160e-02, -6.0160e-02, 2.3433e-02,
1.4381e-01, -1.0833e-01, -2.1113e-01, -4.0786e-02, 5.4556e-01,
-2.5435e-01, 2.0931e-01, -1.6533e-01, 5.7826e-02, 2.4129e-01,
-6.5470e-03, 1.0573e-01, 4.2425e-02, 1.5857e-01, -2.6230e-01,
6.0087e-02, 9.3488e-02, 5.1482e-02, -2.3391e-02, 1.5080e-01,
-8.1993e-03, 1.0456e-01, 9.1285e-02, 1.5626e-01, 1.1439e-02,
4.7808e-03, -2.6726e-01, 5.3531e-02, -3.6573e-01, -2.3823e-02],
[ 1.2500e-02, 3.7133e-01, -1.4034e-01, 3.9198e-02, -3.6731e-02,
4.5655e-01, -7.4721e-02, 4.2471e-01, 4.9195e-02, -3.0580e-02,
3.3413e-02, -6.8495e-02, -1.9100e-01, 3.6335e-02, -1.9796e-01,
-1.0472e-01, -4.4844e-03, 1.5869e-02, -1.5608e-01, -3.7370e-01,
1.7851e-01, 1.1554e-01, -2.4501e-03, 2.4786e-01, -1.8886e-01,
-3.4925e-02, 8.6353e-02, -5.1049e-03, 2.1715e-02, 1.6269e-01,
-1.4769e-01, 6.4491e-02, 5.3724e-02, 2.4163e-01, 1.0518e-02,
-1.2475e-01, 7.9642e-02, 1.0490e-01, -6.9637e-02, 8.0978e-02,
1.0364e-01, -7.0763e-02, 1.7290e-02, 6.2098e-02, 7.8224e-02,
1.6668e-02, -3.3680e-02, -7.9051e-02, -4.8204e-01, 1.4291e-01,
5.3678e-01, -5.3431e-02, 2.7816e-01, 1.6967e-01, 5.4366e-01,
6.1173e-02, 1.3306e-01, -8.3422e-02, -3.2110e-02, 8.0666e-01,
2.1188e-01, 2.7693e-02, -1.1357e-01, -1.6673e-03, -9.6522e-05,
-9.0332e-02, -1.7775e-03, -3.8032e-02, -1.5212e-01, -1.2126e-01,
-3.0774e-02, -2.2333e-01, 3.0382e-02, 1.7606e-01, -3.3427e-01,
-1.1134e-01, -4.7750e-01, 1.5241e-01, 1.9317e-01, 2.0194e-02,
-1.1334e-01, 1.5329e-01, -7.5817e-02, -2.4383e-01, 3.1356e-01,
-5.0673e-02, 3.6370e-02, -3.7387e-03, -6.6226e-02, -8.6424e-02,
-1.3998e-01, -1.2383e-02, -3.8332e-02, -2.1968e-01, 1.0451e-01,
-9.3379e-02, -1.7200e-01, -1.1332e-01, 3.6202e-01, -3.0144e-02],
[ 4.6659e-02, -2.4959e-04, 2.3426e-01, -5.0069e-02, 4.5034e-01,
-2.9681e-01, 9.8321e-02, -1.0126e-01, 1.7274e-01, 9.9580e-02,
2.6660e-02, 1.3632e-01, 3.6400e-01, -6.2227e-04, 3.3356e-01,
3.1854e-01, 2.5345e-02, -2.1220e-01, 1.0711e-01, -3.2063e-01,
8.1584e-02, -1.1157e-01, 5.7534e-02, -7.1183e-02, 2.2101e-01,
1.1482e-01, -1.6558e-01, 1.8002e-01, 1.1658e-01, -1.7561e-01,
3.7775e-01, -7.2300e-02, 2.2794e-01, 1.2333e-01, 8.2356e-02,
-1.9581e-02, -4.1038e-01, 3.5549e-02, 4.5166e-02, 7.3371e-03,
1.2843e-01, -4.4556e-02, -1.9091e-01, -2.5438e-02, 1.3717e-01,
-7.4617e-02, -3.7626e-02, 1.4992e-01, 2.9686e-01, -2.1618e-02,
-2.0024e-02, 1.1072e-01, 1.8701e-01, -1.2697e-01, -1.9988e-01,
6.1961e-02, 2.3703e-01, -3.5152e-02, -4.0518e-02, -8.4783e-02,
1.6145e-02, -1.6552e-01, -2.3104e-03, 1.8262e-01, 7.8389e-02,
-5.8572e-02, 6.0145e-02, 1.4109e-01, 2.5367e-03, -7.1306e-02,
6.4362e-02, 3.9111e-01, 1.8736e-02, 1.1172e-01, 2.6443e-02,
6.5434e-01, 8.7661e-01, -7.2438e-02, -2.7331e-02, 6.6680e-02,
-1.2558e-01, -5.2908e-02, 1.6369e-01, 7.8277e-01, -1.8979e-02,
1.9800e-01, 2.7563e-01, 1.6517e-01, -4.0421e-02, -9.4904e-02,
2.2520e-01, -1.0231e-02, 2.7240e-03, 6.0567e-02, -3.0439e-02,
8.6276e-02, -5.8198e-02, 1.4592e-01, 2.6837e-01, 4.3547e-02],
[ 8.4125e-02, -5.1851e-02, 1.1598e-01, -2.1932e-01, -1.6354e-01,
7.6671e-02, -4.1862e-01, 6.3160e-01, -1.0527e+00, 7.4813e-02,
-3.1027e-01, 1.2660e-02, -1.6032e-01, -9.4683e-02, -1.7580e+00,
-4.9546e-01, 9.8908e-02, 5.0030e-01, -3.1702e-02, -3.7284e-01,
-5.0620e-02, -9.6933e-02, 6.6228e-02, -3.1463e-01, 4.5886e-02,
4.0198e-01, 2.0504e-02, -6.8307e-02, 6.4647e-02, 2.3587e-02,
2.2201e-02, -9.0347e-02, 2.3810e-02, 1.7218e-01, 1.6680e-01,
1.7557e-01, -8.1318e-01, -1.2789e-01, -1.0980e-01, 8.1767e-02,
8.8020e-02, -1.1889e-01, 3.1082e-01, 1.7859e-01, -5.2875e-01,
-8.5934e-02, 1.4063e-02, -1.4435e-01, 3.1226e-02, 1.3918e-01,
2.9591e-01, 9.6788e-02, -6.5919e-02, 4.9293e-01, 4.5455e-01,
-8.7932e-02, 3.0168e-01, -3.2154e-02, 7.4244e-02, -2.0943e-02,
-8.3931e-01, -1.1242e-01, -5.2607e-02, 9.7797e-02, -7.4917e-02,
1.2305e-01, -9.9438e-02, -6.5583e-02, -3.5105e-02, -9.1418e-02,
-2.2342e-02, 6.2180e-02, 6.4697e-01, 4.4723e-02, -2.9072e-01,
2.8315e-01, -3.1691e-01, 6.4718e-02, 2.0533e-01, -2.2795e-01,
-9.5358e-02, 4.9034e-02, -3.0270e-02, -2.4483e-01, -4.5891e-02,
1.2248e-02, -6.0542e-02, -2.5736e-01, -6.9711e-02, 4.1689e-01,
-4.3357e-02, 3.9361e-02, 2.3762e-03, 2.7235e-02, -1.7730e-01,
-1.9130e-01, 2.8159e-01, -4.1389e-02, -3.6894e-01, -4.5783e-02]])), ('classification_layer.bias', tensor([-0.0437, -0.9516, 0.6422, 0.0229, 0.5154, -0.4206, -0.7157, -0.3382,
0.6277, 0.6935])), ('hidden1_bn.weight', tensor([0.4632, 0.0693, 0.4676, 0.0912, 0.8992, 0.0107, 0.2437, 0.3002, 0.5073,
0.3785, 0.4168, 0.2188, 0.5564, 0.3978, 0.5550, 0.4008, 0.9480, 0.2032,
0.0950, 0.9562, 0.2036, 0.1049, 0.8202, 0.6890, 0.1459, 0.5184, 0.9886,
0.0288, 0.3081, 0.5502, 0.3616, 0.2362, 0.5752, 0.7971, 0.6464, 0.6093,
0.6319, 0.6932, 0.5754, 0.7061, 0.1426, 0.5505, 0.6314, 0.5166, 0.7559,
0.6663, 0.3720, 0.0903, 0.4769, 0.2049, 0.6687, 0.4565, 0.7206, 0.8735,
0.6352, 0.6227, 0.4973, 0.2230, 0.2906, 0.7680, 0.3271, 0.6717, 0.9873,
0.8300, 0.3160, 0.3024, 0.0135, 0.3432, 0.9397, 0.4456, 0.4240, 0.2521,
0.1084, 0.1101, 0.3857, 0.2515, 0.6182, 0.7026, 0.6060, 0.8159, 0.6365,
0.8266, 0.8583, 0.7963, 0.3495, 0.1919, 0.7465, 0.2586, 0.7636, 0.6191,
0.7115, 0.4252, 0.6900, 0.5011, 0.2227, 0.4763, 0.6764, 0.1176, 0.8967,
0.5297])), ('hidden1_bn.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden1_bn.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden1_bn.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('hidden1_bn.num_batches_tracked', tensor(0)), ('hidden2_bn.weight', tensor([0.6475, 0.1476, 0.0940, 0.0261, 0.5767, 0.7540, 0.3665, 0.0262, 0.0355,
0.0341, 0.4112, 0.9077, 0.4641, 0.0622, 0.9530, 0.4326, 0.0157, 0.4790,
0.4019, 0.4963, 0.8927, 0.4591, 0.3768, 0.4285, 0.1262, 0.2269, 0.4734,
0.1281, 0.0630, 0.6728, 0.9172, 0.4068, 0.5742, 0.0570, 0.9664, 0.5743,
0.4197, 0.6693, 0.5954, 0.7664, 0.1576, 0.5143, 0.3858, 0.2389, 0.1980,
0.2186, 0.4176, 0.2282, 0.3032, 0.9754, 0.9064, 0.3265, 0.1897, 0.6833,
0.7502, 0.4992, 0.9084, 0.7501, 0.7682, 0.3088, 0.6656, 0.0010, 0.0890,
0.2017, 0.2345, 0.2617, 0.5082, 0.8750, 0.8884, 0.8557, 0.7229, 0.0018,
0.9673, 0.8800, 0.2885, 0.0765, 0.1365, 0.5506, 0.2979, 0.4409, 0.2962,
0.7135, 0.5460, 0.9984, 0.3038, 0.4950, 0.1830, 0.8730, 0.7314, 0.5932,
0.6564, 0.2105, 0.9765, 0.2568, 0.7231, 0.3166, 0.0087, 0.1504, 0.8817,
0.2414])), ('hidden2_bn.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden2_bn.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden2_bn.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('hidden2_bn.num_batches_tracked', tensor(0)), ('hidden3_bn.weight', tensor([0.0183, 0.1703, 0.0816, 0.0073, 0.7481, 0.0045, 0.3684, 0.6449, 0.7166,
0.2513, 0.7362, 0.9478, 0.0579, 0.8907, 0.5410, 0.9047, 0.9532, 0.4001,
0.8993, 0.2649, 0.4780, 0.8824, 0.5346, 0.5739, 0.8368, 0.6350, 0.3515,
0.2345, 0.9436, 0.4721, 0.3576, 0.0944, 0.5854, 0.5526, 0.5765, 0.7673,
0.8020, 0.7514, 0.4501, 0.0259, 0.0312, 0.5814, 0.6849, 0.7483, 0.6331,
0.8805, 0.2422, 0.1488, 0.3588, 0.2841, 0.4533, 0.7722, 0.6284, 0.2670,
0.0777, 0.8324, 0.4633, 0.8356, 0.1231, 0.7873, 0.4009, 0.3379, 0.4591,
0.0550, 0.4897, 0.8159, 0.8478, 0.6804, 0.6224, 0.7077, 0.6013, 0.7264,
0.9880, 0.2310, 0.6292, 0.1254, 0.8500, 0.7606, 0.5549, 0.7801, 0.0566,
0.1811, 0.6724, 0.4320, 0.2750, 0.8118, 0.7839, 0.6223, 0.0229, 0.8085,
0.9893, 0.1615, 0.7277, 0.8736, 0.1750, 0.1782, 0.1602, 0.5801, 0.4103,
0.8275])), ('hidden3_bn.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden3_bn.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])), ('hidden3_bn.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('hidden3_bn.num_batches_tracked', tensor(0))])
# initialize a new model with the saved parameters
new_model = FeedForwardNeuralNetwork(input_size, hidden_size, output_size)
new_model.load_state_dict(saved_parametes)
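load_state_dict only copies tensors into a model that already exists, so the new model has to be built with the same architecture (here the same input_size, hidden_size and output_size). The sketch below illustrates this with a hypothetical make_model helper standing in for FeedForwardNeuralNetwork: a matching model reproduces the original outputs exactly, while a model with a different hidden size is rejected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model(hidden_size):
    # Hypothetical helper standing in for FeedForwardNeuralNetwork(...).
    return nn.Sequential(nn.Linear(784, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 10))


model = make_model(100)
state = model.state_dict()

# Same architecture: loading succeeds and both models compute identical outputs.
same = make_model(100)
same.load_state_dict(state)
x = torch.randn(5, 784)
outputs_match = torch.equal(model(x), same(x))

# Different hidden size: the tensor shapes no longer match, so loading raises an error.
try:
    make_model(50).load_state_dict(state)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```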
3.4.2
Use the evaluate function to compute the loss and accuracy of new_model on test_loader.
# TODO
new_test_loss, new_test_accuracy = evaluate(test_loader, new_model, loss_fn)
message = 'Average loss: {:.4f}, Accuracy: {:.4f}'.format(new_test_loss, new_test_accuracy)
print(message)
Average loss: 14.7253, Accuracy: 95.2300
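The evaluate function itself is defined earlier in the notebook. For reference, a typical evaluation loop might look like the sketch below; evaluate_sketch, the flattening step and the per-batch loss averaging are assumptions, and the notebook's own evaluate may average the loss differently. The key points are model.eval(), which makes dropout and batch normalization behave in inference mode, and torch.no_grad(), which skips gradient tracking.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_sketch(loader, model, loss_fn):
    """Hypothetical evaluation loop: returns (average loss per batch, accuracy in %)."""
    model.eval()  # disable dropout and use the batch-norm running statistics
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, targets in loader:
            outputs = model(inputs.view(inputs.size(0), -1))  # flatten 28x28 images
            total_loss += loss_fn(outputs, targets).item()
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            count += targets.size(0)
    return total_loss / len(loader), 100.0 * correct / count


# Usage on random stand-in data shaped like MNIST batches.
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16)
model = nn.Sequential(nn.Linear(784, 100), nn.ReLU(), nn.Linear(100, 10))
avg_loss, acc = evaluate_sketch(loader, model, nn.CrossEntropyLoss())
```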
4. Advanced Training
4.1 l2_norm
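We can minimize the L2 regularization term below by setting the weight_decay argument of the SGD optimizer:
\begin{equation}
L_{norm} = \sum_{i=1}^{m} \theta_{i}^{2}
\end{equation}
where θ_1, ..., θ_m are the model parameters (every tensor passed to the optimizer, including biases). With weight_decay = λ, SGD adds λθ_i to the gradient of each parameter, which is equivalent to adding (λ/2) L_norm to the loss.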
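The equivalence between weight_decay and an explicit penalty can be checked numerically. In this sketch (the tiny nn.Linear model, random data and the values of lam and lr are illustrative assumptions), one copy of a model takes an SGD step with weight_decay=lam, and an identical copy takes a plain SGD step on the cross-entropy plus (lam / 2) times the sum of squared parameters; the updated parameters agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lam, lr = 0.01, 0.1
x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
loss_fn = nn.CrossEntropyLoss()

model_a = nn.Linear(4, 3)
model_b = nn.Linear(4, 3)
model_b.load_state_dict(model_a.state_dict())  # identical starting weights

# (a) L2 penalty through SGD's weight_decay argument.
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=lam)
opt_a.zero_grad()
loss_fn(model_a(x), y).backward()
opt_a.step()

# (b) The same penalty written explicitly as (lam / 2) * sum(theta^2) in the loss.
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
penalty = sum((p ** 2).sum() for p in model_b.parameters())
(loss_fn(model_b(x), y) + lam / 2 * penalty).backward()
opt_b.step()

same_update = all(
    torch.allclose(pa, pb, atol=1e-6)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
```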
4.1.1 l2_norm = 0.01
Set l2_norm = 0.01 and train the model.
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0.01 # use l2 penalty
get_grad = False
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# L2 regularization is applied through SGD's weight_decay argument
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.9034, Accuracy: 74.8583
Epoch: 1/5. Validation set: Average loss: 0.9461, Accuracy: 75.4200
Epoch: 2/5. Train set: Average loss: 0.6313, Accuracy: 86.2433
Epoch: 2/5. Validation set: Average loss: 0.4580, Accuracy: 86.5500
Epoch: 3/5. Train set: Average loss: 0.4135, Accuracy: 89.0417
Epoch: 3/5. Validation set: Average loss: 0.3631, Accuracy: 89.3100
Epoch: 4/5. Train set: Average loss: 0.3531, Accuracy: 90.2200
Epoch: 4/5. Validation set: Average loss: 0.3268, Accuracy: 90.4500
Epoch: 5/5. Train set: Average loss: 0.3227, Accuracy: 90.9317
Epoch: 5/5. Validation set: Average loss: 0.3030, Accuracy: 91.1100
[Figure: accuracy curves ('Accs'), train vs. test]
[Figure: loss curves ('Losses'), train vs. test]
4.1.2 Problem 5
Consider how the weight of the regularization term in the loss affects training: train the model with l2_norm = 1.
Hint: because Jupyter keeps the variables from earlier cells, the model and the optimizer need to be redefined before training again. They can be redefined with the following code. Note that the default initialization is used here.
# TODO
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 1 # use l2 penalty
get_grad = False
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# L2 regularization is applied through SGD's weight_decay argument
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
# TODO
# Train
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 2.3071, Accuracy: 11.2367
Epoch: 1/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 2/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 2/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 3/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 3/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 4/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 4/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
Epoch: 5/5. Train set: Average loss: 2.3073, Accuracy: 11.2367
Epoch: 5/5. Validation set: Average loss: 2.3024, Accuracy: 11.3500
[Figure: accuracy curves ('Accs'), train vs. test]
[Figure: loss curves ('Losses'), train vs. test]
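With l2_norm = 1 the penalty dominates the loss and the model does not learn at all: the loss stays at about 2.30, close to ln 10 ≈ 2.3026 (the cross-entropy of a uniform prediction over 10 classes), and the accuracy stays at about 11%, which is what always predicting the most frequent digit would give. An L2 penalty that is too large therefore ruins training, whereas l2_norm = 0.01 above still reached about 91% test accuracy.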
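A back-of-the-envelope calculation shows why l2_norm = 1 is so destructive. Ignoring the cross-entropy gradient, each plain SGD step (no momentum) multiplies every weight by (1 - lr * l2_norm). The sketch assumes 60,000 MNIST training images and batch_size = 128, so the step count per epoch is an assumption about the data set rather than a value printed by the notebook.

```python
import math

learning_rate = 0.01
steps_per_epoch = math.ceil(60000 / 128)  # assumes 60,000 training images, batch_size=128

# Weight decay alone multiplies every weight by (1 - lr * l2_norm) on each SGD step.
shrink_per_epoch = {
    l2_norm: (1 - learning_rate * l2_norm) ** steps_per_epoch
    for l2_norm in (0.01, 1.0)
}

# Cross-entropy of a uniform prediction over 10 classes, for comparison with the logged loss.
uniform_loss = math.log(10)

print(f"steps per epoch: {steps_per_epoch}")
for l2_norm, factor in shrink_per_epoch.items():
    print(f"l2_norm={l2_norm}: weights scaled by {factor:.4f} per epoch from decay alone")
print(f"uniform-prediction loss: {uniform_loss:.4f}")
```

With l2_norm = 0.01 the decay removes only a few percent of each weight per epoch, while with l2_norm = 1 it shrinks the weights by more than a factor of 100 per epoch, far faster than the data gradient can rebuild them.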
4.2 dropout
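During training, dropout randomly zeroes some of the elements of the input tensor with probability p, using samples from a Bernoulli distribution.
Each element is zeroed out independently on every forward call. The remaining elements are scaled by 1/(1-p) during training, so that at evaluation time dropout simply passes its input through unchanged.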
Hint: because Jupyter keeps the variables from earlier cells, the model and the optimizer need to be redefined before training again. They can be redefined with the following code. Note that the default initialization is used here.
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# L2 regularization is applied through SGD's weight_decay argument (0 disables it)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
# Set dropout to True and probability = 0.5
model.set_use_dropout(True)
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 0.3335, Accuracy: 92.6233
Epoch: 1/5. Validation set: Average loss: 0.2438, Accuracy: 92.5300
Epoch: 2/5. Train set: Average loss: 0.3065, Accuracy: 93.3100
Epoch: 2/5. Validation set: Average loss: 0.2221, Accuracy: 93.1600
Epoch: 3/5. Train set: Average loss: 0.2794, Accuracy: 93.8617
Epoch: 3/5. Validation set: Average loss: 0.2036, Accuracy: 93.6500
Epoch: 4/5. Train set: Average loss: 0.2576, Accuracy: 94.3500
Epoch: 4/5. Validation set: Average loss: 0.1894, Accuracy: 94.1400
Epoch: 5/5. Train set: Average loss: 0.2373, Accuracy: 94.7400
Epoch: 5/5. Validation set: Average loss: 0.1768, Accuracy: 94.5100
[Figure: accuracy curves ('Accs'), train vs. test]
[Figure: loss curves ('Losses'), train vs. test]
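The dropout behavior described at the start of this section can be observed directly with torch.nn.Dropout. This sketch is independent of the notebook's set_use_dropout method and simply feeds a tensor of ones through a dropout layer with p = 0.5: in training mode each entry becomes either 0 or 1/(1-p) = 2, and in evaluation mode the input passes through unchanged. Calling model.train() or model.eval() switches all dropout layers of a model between these two modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()  # training mode: random zeroing plus rescaling
y_train = dropout(x)
kept_values = set(y_train.unique().tolist())  # survivors are scaled to 1 / (1 - p) = 2.0
zero_fraction = (y_train == 0).float().mean().item()  # close to p = 0.5

dropout.eval()  # evaluation mode: dropout is the identity
y_eval = dropout(x)
```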
4.3 batch_normalization
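Batch normalization is a technique for improving the performance and stability of neural network training. For each feature, it normalizes the input with the mean and variance of the current mini-batch and then applies a learnable scale and shift: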
\begin{equation}
y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} * \gamma + \beta,
\end{equation}
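where γ and β are learnable parameters (the weight and bias of the batch-norm layer) and ε is a small constant added for numerical stability. At evaluation time, E[x] and Var[x] are replaced by the running estimates stored in running_mean and running_var.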
Hint: because Jupyter keeps the variables from earlier cells, the model and the optimizer need to be redefined before training again. They can be redefined with the following code. Note that the default initialization is used here.
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
model.set_use_bn(True)
model.use_bn
True
train_accs, train_losses, test_losses, test_accs = fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.0761, Accuracy: 91.1733
Epoch: 1/5. Validation set: Average loss: 0.4680, Accuracy: 91.1000
Epoch: 2/5. Train set: Average loss: 0.3410, Accuracy: 94.5100
Epoch: 2/5. Validation set: Average loss: 0.2490, Accuracy: 94.1800
Epoch: 3/5. Train set: Average loss: 0.2136, Accuracy: 95.9850
Epoch: 3/5. Validation set: Average loss: 0.1795, Accuracy: 95.5600
Epoch: 4/5. Train set: Average loss: 0.1589, Accuracy: 96.8617
Epoch: 4/5. Validation set: Average loss: 0.1459, Accuracy: 96.3400
Epoch: 5/5. Train set: Average loss: 0.1268, Accuracy: 97.4000
Epoch: 5/5. Validation set: Average loss: 0.1269, Accuracy: 96.6400
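[Figure: Accs, training vs. validation accuracy per epoch]
[Figure: Losses, training vs. validation loss per epoch]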
4.4 Data augmentation
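Data augmentation applies random transformations to the training images so that the model sees more varied inputs. More elaborate augmentation can be used to gain better generalization on the test dataset. The three pipelines below add a random horizontal flip, a random crop, and both combined.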
# only add random horizontal flip
train_transform_1 = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
# Normalize a tensor image with mean and standard deviation
transforms.Normalize((0.1307,), (0.3081,))
])
# only add random crop
train_transform_2 = transforms.Compose([
transforms.RandomCrop(size=[28,28], padding=4),
transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
# Normalize a tensor image with mean and standard deviation
transforms.Normalize((0.1307,), (0.3081,))
])
# add random horizontal flip and random crop
train_transform_3 = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(size=[28,28], padding=4),
transforms.ToTensor(), # Convert a PIL Image or numpy.ndarray to tensor.
# Normalize a tensor image with mean and standard deviation
transforms.Normalize((0.1307,), (0.3081,))
])
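Assuming torchvision's default zero padding, RandomCrop(size=[28,28], padding=4) pads each side by 4 pixels and cuts out a random 28x28 window, i.e. it shifts the digit by up to 4 pixels in each direction, while RandomHorizontalFlip mirrors the image with probability 0.5. The helper names random_pad_crop and horizontal_flip below are hypothetical, and this is a sketch on a (C, H, W) tensor rather than torchvision's implementation (which operates on the PIL image before ToTensor):

```python
import torch
import torch.nn.functional as F


def random_pad_crop(img, size=28, padding=4, offset=None):
    """Illustrative stand-in for RandomCrop(size=[28, 28], padding=4).

    img: tensor of shape (C, H, W). It is zero-padded by `padding` pixels on
    every side, then a size x size window is cut out. `offset` = (top, left)
    fixes the window; by default it is drawn uniformly at random.
    """
    # F.pad pads the last dims first: (left, right, top, bottom)
    padded = F.pad(img, (padding, padding, padding, padding))
    max_top = padded.shape[-2] - size
    max_left = padded.shape[-1] - size
    if offset is None:
        top = int(torch.randint(0, max_top + 1, (1,)))
        left = int(torch.randint(0, max_left + 1, (1,)))
    else:
        top, left = offset
    return padded[..., top:top + size, left:left + size]


def horizontal_flip(img):
    """Mirror the image left-right, as RandomHorizontalFlip does when it fires."""
    return torch.flip(img, dims=[-1])


img = torch.arange(1, 28 * 28 + 1, dtype=torch.float32).reshape(1, 28, 28)
centered = random_pad_crop(img, offset=(4, 4))  # window aligned with the original
shifted = random_pad_crop(img, offset=(0, 0))   # content moved 4 px down and right
print(torch.equal(centered, img))               # True
print(shifted[0, :4].abs().sum().item())        # 0.0: the top 4 rows are padding
```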
# reload train_loader using train_transform_1
train_dataset_1 = torchvision.datasets.MNIST(root='./data',
train=True,
transform=train_transform_1,
download=False)
train_loader_1 = torch.utils.data.DataLoader(dataset=train_dataset_1,
batch_size=batch_size,
shuffle=True)
print(train_dataset_1)
Dataset MNIST
Number of datapoints: 60000
Split: train
Root Location: ./data
Transforms (if any): Compose(
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=(0.1307,), std=(0.3081,))
)
Target Transforms (if any): None
### Hyper parameters
batch_size = 128
n_epochs = 5
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = False
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
train_accs, train_losses, test_losses, test_accs = fit(train_loader_1, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 2.0015, Accuracy: 66.7167
Epoch: 1/5. Validation set: Average loss: 1.2088, Accuracy: 67.6700
Epoch: 2/5. Train set: Average loss: 0.8502, Accuracy: 78.9600
Epoch: 2/5. Validation set: Average loss: 0.6482, Accuracy: 79.7700
Epoch: 3/5. Train set: Average loss: 0.6221, Accuracy: 82.1050
Epoch: 3/5. Validation set: Average loss: 0.5469, Accuracy: 82.7900
Epoch: 4/5. Train set: Average loss: 0.5425, Accuracy: 83.7417
Epoch: 4/5. Validation set: Average loss: 0.4863, Accuracy: 84.2700
Epoch: 5/5. Train set: Average loss: 0.4813, Accuracy: 85.9383
Epoch: 5/5. Validation set: Average loss: 0.4333, Accuracy: 86.1800
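[Figure: Accs, training vs. validation accuracy per epoch]
[Figure: Losses, training vs. validation loss per epoch]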
4.5 Problem 6
Use the provided train_transform_2 and train_transform_3 to reload train_loader, and train with fit.
Hint: because Jupyter keeps variables from earlier cells, the model and the optimizer need to be re-declared. Note that the default initialization is used here.
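One way to follow this hint without copying the whole hyperparameter cell is a small factory that returns a fresh model, loss and optimizer. make_experiment and TinyNet below are hypothetical names; TinyNet only stands in for FeedForwardNeuralNetwork so the sketch runs on its own. The key point is that an optimizer keeps references to the parameters of the model it was created with:

```python
import torch
import torch.nn as nn


def make_experiment(model_cls, learning_rate=0.01, l2_norm=0.0, **model_kwargs):
    """Hypothetical helper: a freshly initialized model, its loss, and an
    SGD optimizer bound to *this* model's parameters."""
    model = model_cls(**model_kwargs)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                weight_decay=l2_norm)
    return model, loss_fn, optimizer


class TinyNet(nn.Module):
    """Stand-in for FeedForwardNeuralNetwork, only so the sketch is runnable."""

    def __init__(self, input_size=784, hidden_size=100, output_size=10):
        super().__init__()
        self.hidden1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.out(torch.relu(self.hidden1(x)))


model_a, _, opt_a = make_experiment(TinyNet)
model_b, _, opt_b = make_experiment(TinyNet)

# An optimizer only updates the tensors it was built with, so reusing an old
# optimizer with a new model would silently keep training the old model.
print(opt_b.param_groups[0]["params"][0] is model_b.hidden1.weight)  # True
print(opt_b.param_groups[0]["params"][0] is model_a.hidden1.weight)  # False
```

With the tutorial's class this would be called as make_experiment(FeedForwardNeuralNetwork, learning_rate=learning_rate, l2_norm=l2_norm, input_size=input_size, hidden_size=hidden_size, output_size=output_size).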
# TODO
# reload train_loader using train_transform_2
train_dataset_2 = torchvision.datasets.MNIST(root='./data',
train=True,
transform=train_transform_2,
download=False)
train_loader_2 = torch.utils.data.DataLoader(dataset=train_dataset_2,
batch_size=batch_size,
shuffle=True)
train_accs, train_losses, test_losses, test_accs = fit(train_loader_2, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 1.3406, Accuracy: 62.0983
Epoch: 1/5. Validation set: Average loss: 0.9176, Accuracy: 74.7300
Epoch: 2/5. Train set: Average loss: 1.0130, Accuracy: 72.4767
Epoch: 2/5. Validation set: Average loss: 0.7144, Accuracy: 79.6100
Epoch: 3/5. Train set: Average loss: 0.7818, Accuracy: 78.9767
Epoch: 3/5. Validation set: Average loss: 0.5295, Accuracy: 84.8800
Epoch: 4/5. Train set: Average loss: 0.6261, Accuracy: 82.3433
Epoch: 4/5. Validation set: Average loss: 0.4338, Accuracy: 87.4800
Epoch: 5/5. Train set: Average loss: 0.5252, Accuracy: 85.5233
Epoch: 5/5. Validation set: Average loss: 0.3735, Accuracy: 89.0300
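[Figure: Accs, training vs. validation accuracy per epoch]
[Figure: Losses, training vs. validation loss per epoch]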
# TODO
# reload train_loader using train_transform_3
train_dataset_3 = torchvision.datasets.MNIST(root='./data',
train=True,
transform=train_transform_3,
download=False)
train_loader_3 = torch.utils.data.DataLoader(dataset=train_dataset_3,
batch_size=batch_size,
shuffle=True)
train_accs, train_losses, test_losses, test_accs = fit(train_loader_3, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
show_curve(train_accs, test_accs, 'Accs')
show_curve(train_losses, test_losses, 'Losses')
Epoch: 1/5. Train set: Average loss: 0.7662, Accuracy: 78.1667
Epoch: 1/5. Validation set: Average loss: 0.4631, Accuracy: 86.4500
Epoch: 2/5. Train set: Average loss: 0.6339, Accuracy: 80.5283
Epoch: 2/5. Validation set: Average loss: 0.4665, Accuracy: 86.0200
Epoch: 3/5. Train set: Average loss: 0.5718, Accuracy: 81.7750
Epoch: 3/5. Validation set: Average loss: 0.4170, Accuracy: 86.5700
Epoch: 4/5. Train set: Average loss: 0.5321, Accuracy: 83.5950
Epoch: 4/5. Validation set: Average loss: 0.3840, Accuracy: 87.6100
Epoch: 5/5. Train set: Average loss: 0.5019, Accuracy: 83.8067
Epoch: 5/5. Validation set: Average loss: 0.3902, Accuracy: 87.7000
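[Figure: Accs, training vs. validation accuracy per epoch]
[Figure: Losses, training vs. validation loss per epoch]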
5. Visualization of training and validation phase
We can use TensorBoard to visualize the training and validation phases.
You can find an example here.
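As a minimal sketch, assuming the tensorboard package is installed, torch.utils.tensorboard.SummaryWriter can log one scalar per tag and epoch. The numbers below are the first three epochs of the dropout run earlier in this notebook, and the temporary log directory is only for illustration:

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

log_dir = tempfile.mkdtemp()  # in a real project, e.g. a "runs/..." folder
writer = SummaryWriter(log_dir=log_dir)

# (epoch, train_loss, val_loss, train_acc, val_acc) from the dropout run
history = [
    (1, 0.3335, 0.2438, 92.6233, 92.5300),
    (2, 0.3065, 0.2221, 93.3100, 93.1600),
    (3, 0.2794, 0.2036, 93.8617, 93.6500),
]
for epoch, train_loss, val_loss, train_acc, val_acc in history:
    # tags sharing a prefix ("Loss/...") are grouped into the same section
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("Loss/validation", val_loss, epoch)
    writer.add_scalar("Accuracy/train", train_acc, epoch)
    writer.add_scalar("Accuracy/validation", val_acc, epoch)
writer.close()

print(os.listdir(log_dir))  # an events.out.tfevents.* file
```

Running tensorboard --logdir pointed at the same directory then shows the curves; in practice the add_scalar calls would sit inside the epoch loop of fit.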
6. Gradient explosion and vanishing
We have embedded code in fit that shows the gradients of the hidden2 and hidden3 layers. By observing how these gradients change, we can see whether they behave normally or are vanishing or exploding.
To plot the gradient changes, set get_grad=True when calling fit.
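The gradient-plotting code inside fit is not shown in this part. A hypothetical sketch of the kind of measurement it could make: after loss.backward(), take the L2 norm of each layer's weight gradient. TinyNet reuses the layer names from the state_dict printed earlier (hidden1 to hidden3, classification_layer), but its ReLU activations are an assumption, and grad_norms is an illustrative name:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in with the same layer names and sizes as the tutorial's model."""

    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 100)
        self.hidden2 = nn.Linear(100, 100)
        self.hidden3 = nn.Linear(100, 100)
        self.classification_layer = nn.Linear(100, 10)

    def forward(self, x):
        x = torch.relu(self.hidden1(x))
        x = torch.relu(self.hidden2(x))
        x = torch.relu(self.hidden3(x))
        return self.classification_layer(x)


def grad_norms(model, layer_names=("hidden2", "hidden3")):
    """L2 norm of each named layer's weight gradient; call after backward()."""
    return {name: getattr(model, name).weight.grad.norm().item()
            for name in layer_names}


torch.manual_seed(0)
model = TinyNet()
x, target = torch.randn(128, 784), torch.randint(0, 10, (128,))
loss = nn.CrossEntropyLoss()(model(x), target)
loss.backward()

norms = grad_norms(model)
print(norms)  # {'hidden2': ..., 'hidden3': ...}
```

Logged once per batch or epoch, norms that shrink toward 0 point to vanishing gradients, while norms that grow rapidly or become inf/nan point to exploding gradients.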
### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = True
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad)
Epoch: 1/15. Train set: Average loss: 1.8883, Accuracy: 77.2633
Epoch: 1/15. Validation set: Average loss: 0.8983, Accuracy: 77.9100
Epoch: 2/15. Train set: Average loss: 0.5687, Accuracy: 87.7217
Epoch: 2/15. Validation set: Average loss: 0.4038, Accuracy: 88.0700
Epoch: 3/15. Train set: Average loss: 0.3675, Accuracy: 89.9283
Epoch: 3/15. Validation set: Average loss: 0.3260, Accuracy: 90.1600
Epoch: 4/15. Train set: Average loss: 0.3123, Accuracy: 91.1600
Epoch: 4/15. Validation set: Average loss: 0.2863, Accuracy: 91.4200
Epoch: 5/15. Train set: Average loss: 0.2793, Accuracy: 92.1150
Epoch: 5/15. Validation set: Average loss: 0.2593, Accuracy: 92.2500
Epoch: 6/15. Train set: Average loss: 0.2543, Accuracy: 92.8367
Epoch: 6/15. Validation set: Average loss: 0.2384, Accuracy: 92.8200
Epoch: 7/15. Train set: Average loss: 0.2336, Accuracy: 93.4067
Epoch: 7/15. Validation set: Average loss: 0.2208, Accuracy: 93.4100
Epoch: 8/15. Train set: Average loss: 0.2155, Accuracy: 93.9067
Epoch: 8/15. Validation set: Average loss: 0.2052, Accuracy: 93.8500
Epoch: 9/15. Train set: Average loss: 0.1995, Accuracy: 94.3783
Epoch: 9/15. Validation set: Average loss: 0.1911, Accuracy: 94.1600
Epoch: 10/15. Train set: Average loss: 0.1854, Accuracy: 94.7917
Epoch: 10/15. Validation set: Average loss: 0.1789, Accuracy: 94.5500
Epoch: 11/15. Train set: Average loss: 0.1727, Accuracy: 95.1583
Epoch: 11/15. Validation set: Average loss: 0.1682, Accuracy: 94.8800
Epoch: 12/15. Train set: Average loss: 0.1615, Accuracy: 95.4683
Epoch: 12/15. Validation set: Average loss: 0.1588, Accuracy: 95.1600
Epoch: 13/15. Train set: Average loss: 0.1516, Accuracy: 95.7700
Epoch: 13/15. Validation set: Average loss: 0.1507, Accuracy: 95.3900
Epoch: 14/15. Train set: Average loss: 0.1427, Accuracy: 96.0317
Epoch: 14/15. Validation set: Average loss: 0.1437, Accuracy: 95.6500
Epoch: 15/15. Train set: Average loss: 0.1348, Accuracy: 96.2417
Epoch: 15/15. Validation set: Average loss: 0.1376, Accuracy: 95.8400
([77.26333333333334,
87.72166666666666,
89.92833333333333,
91.16,
92.115,
92.83666666666667,
93.40666666666667,
93.90666666666667,
94.37833333333333,
94.79166666666667,
95.15833333333333,
95.46833333333333,
95.77,
96.03166666666667,
96.24166666666666],
[1.8883255884433403,
0.5687443313117211,
0.36754155533117616,
0.31234517640983445,
0.27934257469625556,
0.25430761317475736,
0.23359582908292356,
0.21554398813690895,
0.1995451689307761,
0.1853731685023532,
0.17268824516835377,
0.16149521451921034,
0.1515944946843844,
0.142730517917846,
0.13476479675971034],
[0.8983050381081014,
0.40381407219020626,
0.32599611438905135,
0.2863018473586704,
0.25928632353868664,
0.23837185495450527,
0.22084368661611894,
0.20515649761014346,
0.19110500274956982,
0.17893974940422214,
0.16822792386895494,
0.15882641767870775,
0.15071836245965353,
0.14373108235341084,
0.1375972312651103],
[77.91,
88.07,
90.16,
91.42,
92.25,
92.82,
93.41,
93.85,
94.16,
94.55,
94.88,
95.16,
95.39,
95.65,
95.84])
[Figure: gradients of the hidden2 and hidden3 layers during training]
6.1.1 Gradient Vanishing
Set learning_rate = 1e-20.
### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1e-20
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = True
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=get_grad)
Epoch: 1/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 1/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 2/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 2/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 3/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 3/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 4/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 4/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 5/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 5/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 6/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 6/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 7/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 7/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 8/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 8/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 9/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 9/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 10/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 10/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 11/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 11/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 12/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 12/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 13/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 13/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 14/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 14/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
Epoch: 15/15. Train set: Average loss: 2.3074, Accuracy: 14.6833
Epoch: 15/15. Validation set: Average loss: 2.3011, Accuracy: 15.2900
([14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334,
14.683333333333334],
[2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883,
2.3074400210991883],
[2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037,
2.3010528570489037],
[15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29,
15.29])
[Figure: gradients of the hidden2 and hidden3 layers during training]
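With learning_rate = 1e-20 the loss and accuracy above do not change at all. A plausible explanation, sketched below with a single stand-in nn.Linear layer: SGD changes each weight by learning_rate * grad, and for weights of typical size that step is far below float32 resolution, so the update rounds away even though the gradient itself is non-zero. The effect on training is the same as if the gradients had vanished:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(784, 100)  # stand-in for one layer of the model
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-20)

before = layer.weight.detach().clone()
loss = layer(torch.randn(128, 784)).pow(2).mean()
loss.backward()
optimizer.step()  # weight <- weight - 1e-20 * grad, computed in float32

grad_size = layer.weight.grad.abs().max().item()
print(grad_size > 0)                               # True: the gradient is not zero
print(torch.equal(layer.weight.detach(), before))  # True: the update rounded away
```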
6.1.2 Gradient Explosion
6.1.2.1 learning rate
Set learning_rate = 1.0168.
### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1.0168
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = True
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=True)
Epoch: 1/15. Train set: Average loss: 2.0630, Accuracy: 26.7583
Epoch: 1/15. Validation set: Average loss: 2.1282, Accuracy: 26.7700
Epoch: 2/15. Train set: Average loss: 2.2670, Accuracy: 10.0900
Epoch: 2/15. Validation set: Average loss: 2.2986, Accuracy: 9.7600
Epoch: 3/15. Train set: Average loss: 2.1061, Accuracy: 18.4283
Epoch: 3/15. Validation set: Average loss: 2.0783, Accuracy: 17.8500
Epoch: 4/15. Train set: Average loss: 2.0247, Accuracy: 19.7433
Epoch: 4/15. Validation set: Average loss: 1.9792, Accuracy: 19.1600
Epoch: 5/15. Train set: Average loss: 1.8996, Accuracy: 27.6817
Epoch: 5/15. Validation set: Average loss: 1.7469, Accuracy: 27.8600
Epoch: 6/15. Train set: Average loss: 1.9673, Accuracy: 19.7900
Epoch: 6/15. Validation set: Average loss: 1.8792, Accuracy: 19.4400
Epoch: 7/15. Train set: Average loss: 1.9726, Accuracy: 19.0433
Epoch: 7/15. Validation set: Average loss: 1.9119, Accuracy: 18.3700
Epoch: 8/15. Train set: Average loss: 1.8971, Accuracy: 19.3833
Epoch: 8/15. Validation set: Average loss: 2.0936, Accuracy: 19.1200
Epoch: 9/15. Train set: Average loss: 2.0608, Accuracy: 21.0750
Epoch: 9/15. Validation set: Average loss: 1.9886, Accuracy: 21.0400
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:80: RuntimeWarning: overflow encountered in square
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:81: RuntimeWarning: overflow encountered in square
Epoch: 10/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 10/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 11/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 11/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 12/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 12/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 13/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 13/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 14/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 14/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 15/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 15/15. Validation set: Average loss: nan, Accuracy: 9.8000
([26.758333333333333,
10.09,
18.428333333333335,
19.743333333333332,
27.68166666666667,
19.79,
19.043333333333333,
19.383333333333333,
21.075,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666],
[2.0630336391110706,
2.267040777664918,
2.1061492269365196,
2.024679694165531,
1.8995909976144123,
1.9672510224020379,
1.9726365409855149,
1.8970981736977894,
2.060765859153536,
nan,
nan,
nan,
nan,
nan,
nan],
[2.1281878012645095,
2.2986260969427565,
2.0783027516135686,
1.9791795317130754,
1.7469420357595515,
1.8791846338706681,
1.9119218029553378,
2.093622092959247,
1.988625618475902,
nan,
nan,
nan,
nan,
nan,
nan],
[26.77,
9.76,
17.85,
19.16,
27.86,
19.44,
18.37,
19.12,
21.04,
9.8,
9.8,
9.8,
9.8,
9.8,
9.8])
[Figure: gradients of the hidden2 and hidden3 layers during training]
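Why a large learning rate diverges can be seen on the simplest possible loss. For f(w) = a*w^2/2 the gradient is a*w, so one gradient-descent step is w <- (1 - lr*a)*w, which shrinks w only when |1 - lr*a| < 1, i.e. 0 < lr < 2/a. A neural network's loss is not quadratic, so this is only a local picture, but it shows how each step can overshoot by more than the last:

```python
def gradient_descent_on_quadratic(lr, a=1.0, w0=1.0, steps=20):
    """Minimize f(w) = a * w**2 / 2 with plain gradient descent.

    The gradient is a * w, so each step multiplies w by (1 - lr * a).
    """
    w = w0
    for _ in range(steps):
        w = w - lr * a * w
    return w


small = gradient_descent_on_quadratic(lr=0.5)  # factor 0.5 per step: shrinks toward 0
large = gradient_descent_on_quadratic(lr=2.5)  # factor -1.5 per step: grows without bound
print(small, large)
```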
6.1.2.2 normalization for input data
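As a minimal sketch of why input scale matters, assuming the MNIST mean 0.1307 and std 0.3081 used in the transforms above: ToTensor maps pixels to [0, 1] and Normalize then centers and rescales them, whereas feeding raw 0-255 pixels would multiply every first-layer pre-activation by 255. The weights below are random stand-ins drawn within PyTorch's default Linear bound 1/sqrt(784), and normalize_pixel is an illustrative name:

```python
import random

MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize above


def normalize_pixel(p):
    """Map a raw 8-bit pixel (0-255) the way ToTensor + Normalize do."""
    return (p / 255.0 - MEAN) / STD


print(normalize_pixel(0), normalize_pixel(255))  # about -0.42 and 2.82

# Pre-activation of one linear unit: sum_i w_i * x_i over the 784 pixels.
random.seed(0)
bound = 1 / 784 ** 0.5
w = [random.uniform(-bound, bound) for _ in range(784)]
x = [random.random() for _ in range(784)]  # pixel values scaled to [0, 1]
scaled = sum(wi * xi for wi, xi in zip(w, x))
raw = sum(wi * (255 * xi) for wi, xi in zip(w, x))  # same image, 0-255 pixels
ratio = raw / scaled
print(ratio)  # 255 up to floating-point rounding
```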
6.1.2.3 unsuitable weight initialization
### Hyper parameters
batch_size = 128
n_epochs = 15
learning_rate = 1
input_size = 28*28
hidden_size = 100
output_size = 10
l2_norm = 0 # without using l2 penalty
get_grad = True
# declare a model
model = FeedForwardNeuralNetwork(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
# Cross entropy
loss_fn = torch.nn.CrossEntropyLoss()
# l2_norm can be done in SGD
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=l2_norm)
# reset every Linear layer's weight and bias by sampling from N(0, 1)
def wrong_weight_bias_reset(model):
    """Initialize the weights and biases of every Linear layer from a normal distribution with mean=0, std=1.
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            # initialize linear layer with mean and std
            mean, std = 0, 1
            # Initialization method
            torch.nn.init.normal_(m.weight, mean, std)
            torch.nn.init.normal_(m.bias, mean, std)
wrong_weight_bias_reset(model)
show_weight_bias(model)
[Figure: weights and biases after wrong_weight_bias_reset]
fit(train_loader, test_loader, model, loss_fn, optimizer, n_epochs, get_grad=True)
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:80: RuntimeWarning: overflow encountered in square
/Users/nino/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:81: RuntimeWarning: overflow encountered in square
Epoch: 1/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 1/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 2/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 2/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 3/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 3/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 4/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 4/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 5/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 5/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 6/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 6/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 7/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 7/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 8/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 8/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 9/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 9/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 10/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 10/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 11/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 11/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 12/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 12/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 13/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 13/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 14/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 14/15. Validation set: Average loss: nan, Accuracy: 9.8000
Epoch: 15/15. Train set: Average loss: nan, Accuracy: 9.8717
Epoch: 15/15. Validation set: Average loss: nan, Accuracy: 9.8000
([9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666,
9.871666666666666],
[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
[9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8, 9.8])
[Figure: gradients of the hidden2 and hidden3 layers during training]
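The N(0, 1) initialization can be compared with PyTorch's default on a single layer. The sketch below (standard-normal stand-in inputs, layer sizes taken from the model) measures the spread of the first layer's outputs: with unit-variance weights each output sums 784 unit-variance terms, so its standard deviation is close to sqrt(784) = 28, versus roughly 0.58 with the default uniform bound 1/sqrt(784):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(128, 784)  # stand-in for a normalized input batch

default_layer = nn.Linear(784, 100)  # PyTorch default: U(-1/sqrt(784), 1/sqrt(784))
wrong_layer = nn.Linear(784, 100)
nn.init.normal_(wrong_layer.weight, 0, 1)  # what wrong_weight_bias_reset does
nn.init.normal_(wrong_layer.bias, 0, 1)

with torch.no_grad():
    default_std = default_layer(x).std().item()
    wrong_std = wrong_layer(x).std().item()

print(default_std)  # roughly 0.58
print(wrong_std)    # roughly 28
```

Each further layer with N(0, 1) weights multiplies the scale again, which is consistent with the overflow warnings and nan losses from the first epoch above. Scaled schemes such as torch.nn.init.xavier_uniform_ or torch.nn.init.kaiming_normal_ are designed to keep this spread roughly constant across layers.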