![](https://img.haomeiwen.com/i3968643/2859c8b8438d4b0f.png)
A plain classification model ignores the order of the words in a text, and it overfits badly.
RNN: the current output depends on the current input and on the previous hidden state.
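As a minimal sketch of that recurrence (not from the original post; the weight names `W_x`, `W_h`, `b` are illustrative):

```python
import torch

def rnn_step(x_t, h_prev, W_x, W_h, b):
    # h_t depends on the current input x_t and the previous hidden state h_prev
    return torch.tanh(x_t @ W_x + h_prev @ W_h + b)
```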
![](https://img.haomeiwen.com/i3968643/e71a85635ab43162.png)
![](https://img.haomeiwen.com/i3968643/66987ed24766bcd6.png)
![](https://img.haomeiwen.com/i3968643/8cbffbff8787dd20.png)
![](https://img.haomeiwen.com/i3968643/aa0230197645d614.png)
![](https://img.haomeiwen.com/i3968643/f7bfb5ccac8d4119.png)
![](https://img.haomeiwen.com/i3968643/85c5ea2b5140b5f8.png)
![](https://img.haomeiwen.com/i3968643/0d110267d59ef603.png)
![](https://img.haomeiwen.com/i3968643/e2445ee89f4a8c40.png)
When the order matters: use an RNN model.
When the order does not matter: an ordinary deep learning model is enough.
```python
import torch
import torchtext
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer             # tokenizer utility
from torchtext.vocab import build_vocab_from_iterator      # vocabulary-building utility

train_iter, test_iter = torchtext.datasets.IMDB()
train_data, test_data = list(train_iter), list(test_iter)

all_classes = set([label for (label, text) in train_data])
num_class = len(all_classes)

tokenizer = get_tokenizer('basic_english')                  # initialize the tokenizer

def yield_tokens(data):
    for _, text in data:
        yield tokenizer(text)

# Build the vocabulary from the training texts, keeping tokens that occur at least 5 times.
vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<pad>", "<unk>"], min_freq=5)
vocab.set_default_index(vocab["<unk>"])                     # unknown tokens map to <unk>

text_pipeline = lambda x: vocab(tokenizer(x))               # text  -> list of token indices
label_pipeline = lambda x: int(x == 'pos')                  # 'pos' -> 1, 'neg' -> 0
```
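As a quick sanity check, the tokenizer and the two pipelines behave as follows (the index values are sample outputs and depend on the vocabulary built above):

```python
tokenizer('This is a book written by Riyue.guanghua')
# ['this', 'is', 'a', 'book', 'written', 'by', 'riyue', '.', 'guanghua']

vocab(['this', 'is', 'a', 'book', 'written', 'by', 'riyue', '.', 'guanghua'])
# [14, 10, 6, 276, 430, 40, 1, 3, 1]   (words missing from the vocabulary map to <unk>, index 1)

text_pipeline('this is a book about pytorch')
# [14, 10, 6, 276, 50, 1]

label_pipeline('pos')
# 1
```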
```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
    label_list = torch.tensor(label_list)
    # Pad every review in the batch to the same length.
    text_list = torch.nn.utils.rnn.pad_sequence(text_list)
    return label_list.to(device), text_list.to(device)

train_dataloader = DataLoader(train_data, batch_size=256, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_data, batch_size=256, shuffle=False, collate_fn=collate_batch)

for i, (l, b) in enumerate(train_dataloader):
    print(l.size(), b.size())
    if i > 9:
        break
# torch.Size([256]) torch.Size([1113, 256])
# torch.Size([256]) torch.Size([1428, 256])
# torch.Size([256]) torch.Size([1182, 256])
# torch.Size([256]) torch.Size([1187, 256])
# torch.Size([256]) torch.Size([1097, 256])
# torch.Size([256]) torch.Size([1106, 256])
# torch.Size([256]) torch.Size([1351, 256])
# torch.Size([256]) torch.Size([1169, 256])
# torch.Size([256]) torch.Size([1172, 256])
# torch.Size([256]) torch.Size([1164, 256])
# torch.Size([256]) torch.Size([1190, 256])
```
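Each batch comes out as a label tensor of shape (256,) and a padded text tensor of shape (seq_len, 256): `pad_sequence` keeps `batch_first=False` by default and pads with 0, which is the index of `<pad>`. A small standalone illustration (not from the post):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(pad_sequence([a, b]))
# tensor([[1, 4],
#         [2, 5],
#         [3, 0]])   -> shape (max_seq_len, batch_size) = (3, 2)
```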
```python
vocab_size = len(vocab)
embedding_dim = 300
hidden_size = 128

class RNN_Net(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super(RNN_Net, self).__init__()
        self.em = nn.Embedding(vocab_size, embedding_dim)    # token index -> dense vector
        self.rnn = nn.LSTM(embedding_dim, hidden_size)       # expects input of shape (seq_len, batch, embedding_dim)
        self.fc1 = nn.Linear(hidden_size, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, inputs):
        x = self.em(inputs)
        x, _ = self.rnn(x)
        x = F.relu(self.fc1(x[-1]))    # keep only the output of the last time step
        x = self.fc2(x)
        return x
```
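As a rough guide (not from the original post), the shapes for one batch, taking the first batch printed above (seq_len = 1113, batch_size = 256):

```python
# inputs          : (1113, 256)        token indices
# self.em(inputs) : (1113, 256, 300)   embedded tokens
# self.rnn(...)[0]: (1113, 256, 128)   LSTM output at every time step
# x[-1]           : (256, 128)         output at the final time step
# fc1 + relu      : (256, 64)
# fc2             : (256, 2)           logits for the two classes
```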
```python
model = RNN_Net(vocab_size, embedding_dim, hidden_size).to(device)
loss_fn = nn.CrossEntropyLoss()

from torch.optim import lr_scheduler
optimizer = torch.optim.Adam(model.parameters(), betas=(0.5, 0.5), lr=0.01)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
```
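StepLR multiplies the learning rate by `gamma` every `step_size` epochs, so with these settings the learning rate is divided by 10 every 15 epochs; by the later epochs it is tiny, which is why the printed metrics barely move:

```python
# epochs  0-14 : lr = 0.01
# epochs 15-29 : lr = 0.001
# epochs 30-44 : lr = 0.0001
# ... divided by 10 every further 15 epochs
```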
```python
def train(dataloader):
    total_acc, total_count, total_loss = 0, 0, 0
    model.train()
    for label, text in dataloader:
        predicted_label = model(text)
        loss = loss_fn(predicted_label, label)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            total_loss += loss.item() * label.size(0)
    return total_loss / total_count, total_acc / total_count

def test(dataloader):
    model.eval()
    total_acc, total_count, total_loss = 0, 0, 0
    with torch.no_grad():
        for label, text in dataloader:
            predicted_label = model(text)
            loss = loss_fn(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            total_loss += loss.item() * label.size(0)
    return total_loss / total_count, total_acc / total_count

def fit(epochs, train_dl, test_dl):
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    for epoch in range(epochs):
        epoch_loss, epoch_acc = train(train_dl)
        epoch_test_loss, epoch_test_acc = test(test_dl)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)
        test_loss.append(epoch_test_loss)
        test_acc.append(epoch_test_acc)
        exp_lr_scheduler.step()
        template = ("epoch:{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ,"
                    "test_loss: {:.5f}, test_acc: {:.1f}%")
        print(template.format(
            epoch, epoch_loss, epoch_acc * 100, epoch_test_loss, epoch_test_acc * 100))
    print("Done!")
    return train_loss, test_loss, train_acc, test_acc

EPOCHS = 100
train_loss, test_loss, train_acc, test_acc = fit(EPOCHS, train_dataloader, test_dataloader)
```
epoch: 0, train_loss: 0.69916, train_acc: 50.2% ,test_loss: 0.69378, test_acc: 50.0%
epoch: 1, train_loss: 0.69399, train_acc: 50.8% ,test_loss: 0.69380, test_acc: 50.0%
epoch: 2, train_loss: 0.69348, train_acc: 50.2% ,test_loss: 0.69605, test_acc: 50.0%
epoch: 3, train_loss: 0.69278, train_acc: 49.8% ,test_loss: 0.69515, test_acc: 50.0%
epoch: 4, train_loss: 0.67794, train_acc: 55.2% ,test_loss: 0.66339, test_acc: 62.8%
epoch: 5, train_loss: 0.57792, train_acc: 71.5% ,test_loss: 0.55734, test_acc: 74.4%
epoch: 6, train_loss: 0.47179, train_acc: 80.5% ,test_loss: 0.51183, test_acc: 77.9%
epoch: 7, train_loss: 0.39794, train_acc: 84.5% ,test_loss: 0.51667, test_acc: 78.6%
epoch: 8, train_loss: 0.33833, train_acc: 87.9% ,test_loss: 0.53838, test_acc: 76.9%
epoch: 9, train_loss: 0.31944, train_acc: 88.9% ,test_loss: 0.46782, test_acc: 82.8%
epoch:10, train_loss: 0.28838, train_acc: 90.3% ,test_loss: 0.50602, test_acc: 82.9%
epoch:11, train_loss: 0.25307, train_acc: 91.8% ,test_loss: 0.48785, test_acc: 78.4%
epoch:12, train_loss: 0.22973, train_acc: 92.8% ,test_loss: 0.48692, test_acc: 83.3%
epoch:13, train_loss: 0.20735, train_acc: 93.5% ,test_loss: 0.48483, test_acc: 81.6%
epoch:14, train_loss: 0.20245, train_acc: 93.9% ,test_loss: 0.46851, test_acc: 83.7%
epoch:15, train_loss: 0.13799, train_acc: 96.0% ,test_loss: 0.49551, test_acc: 84.0%
epoch:16, train_loss: 0.10263, train_acc: 97.3% ,test_loss: 0.52576, test_acc: 84.2%
epoch:17, train_loss: 0.08577, train_acc: 97.9% ,test_loss: 0.55930, test_acc: 84.4%
epoch:18, train_loss: 0.07269, train_acc: 98.3% ,test_loss: 0.59011, test_acc: 84.5%
epoch:19, train_loss: 0.06335, train_acc: 98.6% ,test_loss: 0.60821, test_acc: 84.3%
epoch:20, train_loss: 0.05706, train_acc: 98.7% ,test_loss: 0.59220, test_acc: 84.1%
epoch:21, train_loss: 0.05131, train_acc: 98.9% ,test_loss: 0.62850, test_acc: 84.5%
epoch:22, train_loss: 0.04518, train_acc: 99.0% ,test_loss: 0.62696, test_acc: 84.0%
epoch:23, train_loss: 0.04158, train_acc: 99.1% ,test_loss: 0.61498, test_acc: 84.6%
epoch:24, train_loss: 0.03753, train_acc: 99.2% ,test_loss: 0.63472, test_acc: 84.1%
epoch:25, train_loss: 0.03328, train_acc: 99.3% ,test_loss: 0.74586, test_acc: 84.6%
epoch:26, train_loss: 0.03213, train_acc: 99.3% ,test_loss: 0.70182, test_acc: 84.6%
epoch:27, train_loss: 0.02766, train_acc: 99.4% ,test_loss: 0.71342, test_acc: 84.6%
epoch:28, train_loss: 0.02531, train_acc: 99.4% ,test_loss: 0.76255, test_acc: 84.4%
epoch:29, train_loss: 0.02227, train_acc: 99.4% ,test_loss: 0.74665, test_acc: 84.7%
epoch:30, train_loss: 0.01811, train_acc: 99.5% ,test_loss: 0.79624, test_acc: 84.7%
epoch:31, train_loss: 0.01743, train_acc: 99.5% ,test_loss: 0.79401, test_acc: 84.7%
epoch:32, train_loss: 0.01700, train_acc: 99.5% ,test_loss: 0.77805, test_acc: 84.6%
epoch:33, train_loss: 0.01606, train_acc: 99.5% ,test_loss: 0.88634, test_acc: 84.8%
epoch:34, train_loss: 0.01522, train_acc: 99.6% ,test_loss: 0.86396, test_acc: 84.8%
epoch:35, train_loss: 0.01509, train_acc: 99.5% ,test_loss: 0.83588, test_acc: 84.8%
epoch:36, train_loss: 0.01447, train_acc: 99.6% ,test_loss: 0.89528, test_acc: 84.9%
epoch:37, train_loss: 0.01449, train_acc: 99.6% ,test_loss: 0.84509, test_acc: 84.7%
epoch:38, train_loss: 0.01412, train_acc: 99.6% ,test_loss: 0.86246, test_acc: 84.7%
epoch:39, train_loss: 0.01367, train_acc: 99.6% ,test_loss: 0.88045, test_acc: 84.7%
epoch:40, train_loss: 0.01302, train_acc: 99.6% ,test_loss: 0.87827, test_acc: 84.6%
epoch:41, train_loss: 0.01281, train_acc: 99.6% ,test_loss: 0.87976, test_acc: 84.7%
epoch:42, train_loss: 0.01269, train_acc: 99.6% ,test_loss: 0.90360, test_acc: 84.7%
epoch:43, train_loss: 0.01230, train_acc: 99.6% ,test_loss: 0.90518, test_acc: 84.8%
epoch:44, train_loss: 0.01219, train_acc: 99.6% ,test_loss: 0.88305, test_acc: 84.6%
epoch:45, train_loss: 0.01102, train_acc: 99.6% ,test_loss: 0.92137, test_acc: 84.7%
epoch:46, train_loss: 0.01103, train_acc: 99.7% ,test_loss: 0.93830, test_acc: 84.7%
epoch:47, train_loss: 0.01136, train_acc: 99.7% ,test_loss: 0.94346, test_acc: 84.7%
epoch:48, train_loss: 0.01122, train_acc: 99.7% ,test_loss: 0.94277, test_acc: 84.7%
epoch:49, train_loss: 0.01131, train_acc: 99.7% ,test_loss: 0.93391, test_acc: 84.6%
epoch:50, train_loss: 0.01118, train_acc: 99.7% ,test_loss: 0.93697, test_acc: 84.6%
epoch:51, train_loss: 0.01134, train_acc: 99.7% ,test_loss: 0.94353, test_acc: 84.6%
epoch:52, train_loss: 0.01143, train_acc: 99.7% ,test_loss: 0.95241, test_acc: 84.6%
epoch:53, train_loss: 0.01114, train_acc: 99.7% ,test_loss: 0.94443, test_acc: 84.6%
epoch:54, train_loss: 0.01099, train_acc: 99.7% ,test_loss: 0.94568, test_acc: 84.6%
epoch:55, train_loss: 0.01099, train_acc: 99.7% ,test_loss: 0.95337, test_acc: 84.6%
epoch:56, train_loss: 0.01099, train_acc: 99.7% ,test_loss: 0.96076, test_acc: 84.6%
epoch:57, train_loss: 0.01115, train_acc: 99.7% ,test_loss: 0.96316, test_acc: 84.6%
epoch:58, train_loss: 0.01116, train_acc: 99.7% ,test_loss: 0.95679, test_acc: 84.6%
epoch:59, train_loss: 0.01068, train_acc: 99.7% ,test_loss: 0.94736, test_acc: 84.5%
epoch:60, train_loss: 0.01125, train_acc: 99.7% ,test_loss: 0.94820, test_acc: 84.5%
epoch:61, train_loss: 0.01074, train_acc: 99.7% ,test_loss: 0.95048, test_acc: 84.6%
epoch:62, train_loss: 0.01092, train_acc: 99.7% ,test_loss: 0.95178, test_acc: 84.6%
epoch:63, train_loss: 0.01093, train_acc: 99.7% ,test_loss: 0.95187, test_acc: 84.6%
epoch:64, train_loss: 0.01095, train_acc: 99.7% ,test_loss: 0.95308, test_acc: 84.6%
epoch:65, train_loss: 0.01046, train_acc: 99.7% ,test_loss: 0.95353, test_acc: 84.6%
epoch:66, train_loss: 0.01062, train_acc: 99.7% ,test_loss: 0.95421, test_acc: 84.6%
epoch:67, train_loss: 0.01178, train_acc: 99.7% ,test_loss: 0.95447, test_acc: 84.6%
epoch:68, train_loss: 0.01068, train_acc: 99.7% ,test_loss: 0.95500, test_acc: 84.6%
epoch:69, train_loss: 0.01233, train_acc: 99.7% ,test_loss: 0.95466, test_acc: 84.6%
epoch:70, train_loss: 0.01081, train_acc: 99.7% ,test_loss: 0.95511, test_acc: 84.6%
epoch:71, train_loss: 0.01076, train_acc: 99.7% ,test_loss: 0.95512, test_acc: 84.6%
epoch:72, train_loss: 0.01071, train_acc: 99.7% ,test_loss: 0.95637, test_acc: 84.6%
epoch:73, train_loss: 0.01064, train_acc: 99.7% ,test_loss: 0.95661, test_acc: 84.6%
epoch:74, train_loss: 0.01066, train_acc: 99.7% ,test_loss: 0.95692, test_acc: 84.6%
epoch:75, train_loss: 0.01073, train_acc: 99.7% ,test_loss: 0.95691, test_acc: 84.6%
epoch:76, train_loss: 0.01051, train_acc: 99.7% ,test_loss: 0.95698, test_acc: 84.6%
epoch:77, train_loss: 0.01069, train_acc: 99.7% ,test_loss: 0.95695, test_acc: 84.6%
epoch:78, train_loss: 0.01095, train_acc: 99.7% ,test_loss: 0.95705, test_acc: 84.6%
epoch:79, train_loss: 0.01087, train_acc: 99.7% ,test_loss: 0.95708, test_acc: 84.6%
epoch:80, train_loss: 0.01084, train_acc: 99.7% ,test_loss: 0.95721, test_acc: 84.6%
epoch:81, train_loss: 0.01073, train_acc: 99.7% ,test_loss: 0.95725, test_acc: 84.6%
epoch:82, train_loss: 0.01055, train_acc: 99.7% ,test_loss: 0.95728, test_acc: 84.6%
epoch:83, train_loss: 0.01067, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:84, train_loss: 0.01059, train_acc: 99.7% ,test_loss: 0.95727, test_acc: 84.6%
epoch:85, train_loss: 0.01121, train_acc: 99.7% ,test_loss: 0.95731, test_acc: 84.6%
epoch:86, train_loss: 0.01114, train_acc: 99.7% ,test_loss: 0.95729, test_acc: 84.6%
epoch:87, train_loss: 0.01157, train_acc: 99.7% ,test_loss: 0.95726, test_acc: 84.6%
epoch:88, train_loss: 0.01057, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:89, train_loss: 0.01149, train_acc: 99.7% ,test_loss: 0.95722, test_acc: 84.6%
epoch:90, train_loss: 0.01070, train_acc: 99.7% ,test_loss: 0.95722, test_acc: 84.6%
epoch:91, train_loss: 0.01275, train_acc: 99.7% ,test_loss: 0.95722, test_acc: 84.6%
epoch:92, train_loss: 0.01057, train_acc: 99.7% ,test_loss: 0.95723, test_acc: 84.6%
epoch:93, train_loss: 0.01061, train_acc: 99.7% ,test_loss: 0.95723, test_acc: 84.6%
epoch:94, train_loss: 0.01045, train_acc: 99.7% ,test_loss: 0.95723, test_acc: 84.6%
epoch:95, train_loss: 0.01157, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:96, train_loss: 0.01182, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:97, train_loss: 0.01054, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:98, train_loss: 0.01092, train_acc: 99.7% ,test_loss: 0.95724, test_acc: 84.6%
epoch:99, train_loss: 0.01061, train_acc: 99.7% ,test_loss: 0.95725, test_acc: 84.6%
Done!
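The model ends at about 99.7% training accuracy versus about 84.6% test accuracy, so it still overfits noticeably. As a minimal sketch (not part of the original post), the trained model could be used on a single review like this; the review text is made up for illustration:

```python
def predict_sentiment(text):
    model.eval()
    with torch.no_grad():
        # (seq_len,) -> (seq_len, 1): the model expects (seq_len, batch_size)
        ids = torch.tensor(text_pipeline(text), dtype=torch.int64).unsqueeze(1).to(device)
        logits = model(ids)
        return 'pos' if logits.argmax(1).item() == 1 else 'neg'

print(predict_sentiment('this movie was surprisingly good , i enjoyed every minute'))
```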