Fully connected, convolutional, and recurrent networks pretty much cover the basics. I'm not familiar with reinforcement learning yet, so I'll leave it aside for now.
1. Code
import tensorflow as tf
import numpy as np


class DataLoader():
    def __init__(self):
        path = tf.keras.utils.get_file('nietzsche.txt',
                                       origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        # build the character vocabulary and the two lookup tables
        self.chars = sorted(list(set(self.raw_text)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.text = [self.char_indices[c] for c in self.raw_text]  # corpus encoded as indices

    def get_batch(self, seq_length, batch_size):
        # sample batch_size random windows of seq_length characters,
        # each paired with the character that immediately follows it
        seq = []
        next_char = []
        for i in range(batch_size):
            index = np.random.randint(0, len(self.text) - seq_length)
            seq.append(self.text[index:index + seq_length])
            next_char.append(self.text[index + seq_length])
        return np.array(seq), np.array(next_char)  # [batch_size, seq_length], [batch_size]
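As an aside, get_batch simply samples random fixed-length windows from the encoded corpus. A quick sanity check, separate from the script itself (hypothetical usage of the class above; the name loader is mine):

loader = DataLoader()
seq, next_char = loader.get_batch(seq_length=40, batch_size=50)
print(seq.shape, next_char.shape)  # (50, 40) (50,)
# decode one window back to text to confirm the index mappings round-trip
print(''.join(loader.indices_char[i] for i in seq[0]))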
class RNN(tf.keras.Model):
    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        inputs = tf.one_hot(inputs, depth=self.num_chars)  # [batch_size, seq_length, num_chars]
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)
        # feed the sequence through the LSTM cell one time step at a time
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)  # predict the next character from the final output
        if from_logits:
            return logits
        else:
            return tf.nn.softmax(logits)

    def predict(self, inputs, temperature=1.):
        batch_size, _ = tf.shape(inputs)
        logits = self(inputs, from_logits=True)
        # temperature rescales the logits before softmax: <1 sharpens, >1 flattens
        prob = tf.nn.softmax(logits / temperature).numpy()
        return np.array([np.random.choice(self.num_chars, p=prob[i, :])  # sample, don't argmax
                         for i in range(batch_size.numpy())])
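A note on the temperature parameter in predict: dividing the logits before the softmax sharpens the distribution when temperature < 1 and flattens it when temperature > 1, which is why the low diversity values used later give conservative, repetitive text and the high values give noisier text. A standalone numpy sketch with made-up logits (illustration only, not part of the script):

import numpy as np

logits = np.array([2.0, 1.0, 0.1])  # illustrative values
for temperature in [0.2, 1.0, 1.2]:
    scaled = np.exp(logits / temperature)
    prob = scaled / scaled.sum()  # softmax(logits / temperature)
    print(temperature, prob.round(3))
# at 0.2 nearly all the mass sits on the largest logit;
# at 1.2 the distribution is noticeably flatter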
num_batches = 1000
seq_length = 40
batch_size = 50
learning_rate = 1e-2

data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print('batch {}: loss {}'.format(batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
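For reference, sparse_categorical_crossentropy with integer labels reduces, per example, to the negative log of the probability the model assigned to the true next character. A minimal numpy check with made-up numbers:

import numpy as np

y_true = 2                                # index of the true next character
y_pred = np.array([0.1, 0.2, 0.6, 0.1])  # a softmax output over 4 classes
print(-np.log(y_pred[y_true]))            # 0.5108..., the per-example loss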
# generate text at several sampling temperatures, starting from one random window
X_, _ = data_loader.get_batch(seq_length, 1)
for diversity in [0.2, 0.5, 1.0, 1.2]:
    X = X_
    print('diversity {}'.format(diversity))
    for t in range(400):
        y_pred = model.predict(X, diversity)
        print(data_loader.indices_char[y_pred[0]], end='', flush=True)  # end='' keeps the characters contiguous
        # slide the window: drop the oldest character, append the prediction
        X = np.concatenate([X[:, 1:], np.expand_dims(y_pred, axis=1)], axis=1)
    print('\n')
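Generation works by sliding a fixed-length window: each sampled character is appended and the oldest one is dropped, so the model always sees the most recent seq_length characters. Incidentally, the manual per-time-step loop over LSTMCell in call could also be written with the vectorized tf.keras.layers.LSTM layer; below is a minimal sketch of an equivalent forward pass (my own variant, not the author's code):

import tensorflow as tf

class RNN2(tf.keras.Model):
    def __init__(self, num_chars):
        super().__init__()
        self.num_chars = num_chars
        self.lstm = tf.keras.layers.LSTM(units=256)  # returns only the last time step's output
        self.dense = tf.keras.layers.Dense(units=num_chars)

    def call(self, inputs, from_logits=False):
        x = tf.one_hot(inputs, depth=self.num_chars)  # [batch_size, seq_length, num_chars]
        output = self.lstm(x)                         # [batch_size, 256]
        logits = self.dense(output)
        return logits if from_logits else tf.nn.softmax(logits)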
2. Output
C:\Users\ccc\AppData\Local\Programs\Python\Python38\python.exe D:/tmp/tup_ai/codes/2.clustering/kmeans/tf_test.py
2022-06-23 10:50:24.140688: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
batch 0: loss 4.046131134033203
batch 1: loss 3.9030046463012695
batch 2: loss 3.5824038982391357
batch 3: loss 3.6495604515075684
batch 4: loss 3.3191030025482178
batch 5: loss 3.4458301067352295
batch 6: loss 3.1989667415618896
batch 7: loss 3.3773651123046875
batch 8: loss 3.073580741882324
batch 9: loss 3.1812446117401123
batch 10: loss 3.3901772499084473
batch 11: loss 3.224184036254883
batch 12: loss 3.1514182090759277
batch 13: loss 3.001044988632202
batch 14: loss 2.9299585819244385
batch 15: loss 2.950603723526001
batch 16: loss 2.9749419689178467
batch 17: loss 3.2935285568237305
batch 18: loss 3.1480045318603516
batch 19: loss 3.1267688274383545
batch 20: loss 3.025820016860962
batch 21: loss 2.9815168380737305
batch 22: loss 3.1005260944366455
batch 23: loss 2.8049328327178955
batch 24: loss 2.9903664588928223
batch 25: loss 2.9383554458618164
batch 26: loss 3.0361881256103516
batch 27: loss 3.1392831802368164
batch 28: loss 3.0993568897247314
batch 29: loss 2.8119354248046875
batch 30: loss 3.322962999343872
batch 31: loss 3.0106008052825928
batch 32: loss 3.151655912399292
batch 33: loss 2.8608169555664062
batch 34: loss 2.972957134246826
batch 35: loss 3.1426093578338623
batch 36: loss 3.001406669616699
batch 37: loss 3.127439498901367
batch 38: loss 2.882235050201416
batch 39: loss 2.9103260040283203
batch 40: loss 3.003803014755249
batch 41: loss 2.795168876647949
batch 42: loss 3.0328474044799805
batch 43: loss 3.023397922515869
batch 44: loss 2.8765347003936768
batch 45: loss 2.903585433959961
batch 46: loss 2.609665632247925
batch 47: loss 3.223663330078125
batch 48: loss 2.8513782024383545
batch 49: loss 3.077909469604492
batch 50: loss 2.857851266860962
batch 51: loss 3.1217010021209717
Process finished with exit code -1