PyTorch Deep Learning: A Concise Hands-On Guide 25 - GRU


Author: 薛东弗斯 | Published: 2023-05-01 17:03
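import torch
import torch.nn as nn
import torch.nn.functional as F

# Sizes inferred from the shape comments below (embedding 100, hidden state 300);
# if they are already defined earlier in the series, keep those definitions instead.
embeding_dim = 100
hidden_size = 300

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.em = nn.Embedding(20002, embeding_dim)    # (200, batch) -> (200, batch, 100)
        self.rnn = nn.GRU(embeding_dim, hidden_size)   # -> (200, batch, 300)
        self.fc1 = nn.Linear(hidden_size, 256)
        self.fc2 = nn.Linear(256, 3)                   # 3 output classes

    def forward(self, x):
        bz = x.shape[1]                                # batch size (input is sequence-first)
        # Initial hidden state, created on the same device as the input
        h0 = torch.zeros((1, bz, hidden_size), device=x.device)
        x = self.em(x)
        o, _ = self.rnn(x, h0)                         # zeros are also the default when h0 is omitted
        x = F.relu(self.fc1(o[-1]))                    # output of the last time step: (batch, 300)
        x = self.fc2(x)
        return x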
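
To make the tensor shapes in the model above easier to follow, here is a minimal sketch that pushes a random batch through an `nn.Embedding` and an `nn.GRU` on the CPU. The small sizes (`vocab_size`, `emb_dim`, `hid`, `seq_len`, `batch`) are illustrative assumptions, not the values used in the article. It also checks that, for a single-layer unidirectional GRU, the output at the last time step equals the final hidden state.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions); the article's vocabulary has 20002 entries.
vocab_size, emb_dim, hid = 50, 8, 16
seq_len, batch = 5, 3

em = nn.Embedding(vocab_size, emb_dim)
gru = nn.GRU(emb_dim, hid)  # batch_first=False: input is (seq_len, batch, emb_dim)

tokens = torch.randint(0, vocab_size, (seq_len, batch))  # (seq_len, batch)
h0 = torch.zeros(1, batch, hid)                          # (num_layers, batch, hid)

embedded = em(tokens)          # (seq_len, batch, emb_dim)
out, h_n = gru(embedded, h0)   # out: (seq_len, batch, hid), h_n: (1, batch, hid)
last = out[-1]                 # (batch, hid), output at the final time step

shapes = (tuple(embedded.shape), tuple(out.shape), tuple(h_n.shape), tuple(last.shape))
same = torch.allclose(last, h_n[0])  # last-step output vs. final hidden state
print(shapes, same)
```

This is why the model can feed `o[-1]` straight into the first linear layer: it is a `(batch, hidden_size)` summary of the whole sequence.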

model = Net()

if torch.cuda.is_available():
    model.to('cuda')

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_iter,
                                                                 test_iter)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
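
The `fit` function called in the loop above is not shown in this section. As a rough idea of what such a function might look like, here is a hypothetical sketch named `fit_sketch`; the helper `run_epoch`, the explicit `loss_fn` and `optimizer` parameters, and the assumption that each batch is an `(inputs, labels)` pair are all illustrative choices, not the author's implementation.

```python
import torch
import torch.nn as nn

def run_epoch(model, batches, loss_fn, optimizer=None):
    """One pass over `batches`: trains if an optimizer is given, otherwise evaluates."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, count = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for x, y in batches:  # assumption: each batch is an (inputs, labels) pair
            logits = model(x)
            loss = loss_fn(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * y.size(0)
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += y.size(0)
    return total_loss / count, correct / count

def fit_sketch(epoch, model, train_batches, test_batches, loss_fn, optimizer):
    """Hypothetical stand-in for `fit`: returns train/test loss and accuracy for one epoch."""
    train_loss, train_acc = run_epoch(model, train_batches, loss_fn, optimizer)
    test_loss, test_acc = run_epoch(model, test_batches, loss_fn)
    print(f"epoch {epoch}: loss {train_loss:.3f} acc {train_acc:.3f} "
          f"test_loss {test_loss:.3f} test_acc {test_acc:.3f}")
    return train_loss, train_acc, test_loss, test_acc
```

Under these assumptions, a call such as `fit_sketch(epoch, model, train_batches, test_batches, nn.CrossEntropyLoss(), optimizer)` returns the same four values that the loop appends to its lists.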

Article link: https://www.haomeiwen.com/subject/oriojdtx.html