美文网首页👨‍💻‍桂圆|技术文摘
中文NLP笔记:11. 基于 LSTM 生成古诗

中文NLP笔记:11. 基于 LSTM 生成古诗

作者: 不会停的蜗牛 | 来源:发表于2019-02-08 23:08 被阅读108次

    基于 LSTM 生成古诗

    1. 语料准备

      一共四万多首古诗,一行一首诗

    2. 预处理

      将汉字表示为 One-Hot 的形式

      在每行末尾加上 ] 符号是为了标识这首诗已经结束,说明 ] 符号之前的语句和之后的语句是没有关联关系的,后面会舍弃掉包含 ] 符号的训练数据。

          puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》']

        def preprocess_file(Config):

            # 语料文本内容

            files_content = ''

            with open(Config.poetry_file, 'r', encoding='utf-8') as f:

                for line in f:

                    # 每行的末尾加上"]"符号代表一首诗结束

                    for char in puncs:

                        line = line.replace(char, "")

                    files_content += line.strip() + "]"

            words = sorted(list(files_content))

            words.remove(']')

            counted_words = {}

            for word in words:

                if word in counted_words:

                    counted_words[word] += 1

                else:

                    counted_words[word] = 1

            # 去掉低频的字

            erase = []

            for key in counted_words:

                if counted_words[key] <= 2:

                    erase.append(key)

            for key in erase:

                del counted_words[key]

            del counted_words[']']

            wordPairs = sorted(counted_words.items(), key=lambda x: -x[1])

            words, _ = zip(*wordPairs)

            # word到id的映射

            word2num = dict((c, i + 1) for i, c in enumerate(words))

            num2word = dict((i, c) for i, c in enumerate(words))

            word2numF = lambda x: word2num.get(x, 0)

            return word2numF, num2word, words, files_content

    3. 模型参数配置

      class Config(object):

        poetry_file = 'poetry.txt'

        weight_file = 'poetry_model.h5'

        # 根据前六个字预测第七个字

        max_len = 6

        batch_size = 512

        learning_rate = 0.001

    4. 构建模型

      通过 PoetryModel 类实现

          class PoetryModel(object):

            def __init__(self, config):

                pass

            def build_model(self):

                pass

            def sample(self, preds, temperature=1.0):

                pass

            def generate_sample_result(self, epoch, logs):

                pass

            def predict(self, text):

                pass

            def data_generator(self):

                pass

            def train(self):

                pass

      (1)init 函数

      加载 Config 配置信息,进行语料预处理和模型加载

          def __init__(self, config):

                self.model = None

                self.do_train = True

                self.loaded_model = False

                self.config = config

                # 文件预处理

                self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)

                if os.path.exists(self.config.weight_file):

                    self.model = load_model(self.config.weight_file)

                    self.model.summary()

                else:

                    self.train()

                self.do_train = False

                self.loaded_model = True

      (2)build_model 函数

      GRU 模型建立

          def build_model(self):

                '''建立模型'''

                input_tensor = Input(shape=(self.config.max_len,))

                embedd = Embedding(len(self.num2word)+1, 300, input_length=self.config.max_len)(input_tensor)

                lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)

                dropout = Dropout(0.6)(lstm)

                lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)

                dropout = Dropout(0.6)(lstm)

                flatten = Flatten()(lstm)

                dense = Dense(len(self.words), activation='softmax')(flatten)

                self.model = Model(inputs=input_tensor, outputs=dense)

                optimizer = Adam(lr=self.config.learning_rate)

                self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

      (3)sample 函数

          def sample(self, preds, temperature=1.0):

            preds = np.asarray(preds).astype('float64')

            preds = np.log(preds) / temperature

            exp_preds = np.exp(preds)

            preds = exp_preds / np.sum(exp_preds)

            probas = np.random.multinomial(1, preds, 1)

            return np.argmax(probas)

      (4)训练模型

          def generate_sample_result(self, epoch, logs): 

                print("\n==================Epoch {}=====================".format(epoch))

                for diversity in [0.5, 1.0, 1.5]:

                    print("------------Diversity {}--------------".format(diversity))

                    start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1)

                    generated = ''

                    sentence = self.files_content[start_index: start_index + self.config.max_len]

                    generated += sentence

                    for i in range(20):

                        x_pred = np.zeros((1, self.config.max_len))

                        for t, char in enumerate(sentence[-6:]):

                            x_pred[0, t] = self.word2numF(char)

                        preds = self.model.predict(x_pred, verbose=0)[0]

                        next_index = self.sample(preds, diversity)

                        next_char = self.num2word[next_index]

                        generated += next_char

                        sentence = sentence + next_char

                    print(sentence)

      (5)predict 函数

      根据给出的文字,生成诗句

          def predict(self, text):

                if not self.loaded_model:

                    return

                with open(self.config.poetry_file, 'r', encoding='utf-8') as f:

                    file_list = f.readlines()

                random_line = random.choice(file_list)

                # 如果给的text不到四个字,则随机补全

                if not text or len(text) != 4:

                    for _ in range(4 - len(text)):

                        random_str_index = random.randrange(0, len(self.words))

                        text += self.num2word.get(random_str_index) if self.num2word.get(random_str_index) not in [',', '。',

                                                                                                                  ','] else self.num2word.get(

                            random_str_index + 1)

                seed = random_line[-(self.config.max_len):-1]

                res = ''

                seed = 'c' + seed

                for c in text:

                    seed = seed[1:] + c

                    for j in range(5):

                        x_pred = np.zeros((1, self.config.max_len))

                        for t, char in enumerate(seed):

                            x_pred[0, t] = self.word2numF(char)

                        preds = self.model.predict(x_pred, verbose=0)[0]

                        next_index = self.sample(preds, 1.0)

                        next_char = self.num2word[next_index]

                        seed = seed[1:] + next_char

                    res += seed

                return res

      (6) data_generator 函数

      生成数据,提供给模型训练时使用

            def data_generator(self):

                i = 0

                while 1:

                    x = self.files_content[i: i + self.config.max_len]

                    y = self.files_content[i + self.config.max_len]

                    puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》', ':']

                    if len([i for i in puncs if i in x]) != 0:

                        i += 1

                        continue

                    if len([i for i in puncs if i in y]) != 0:

                        i += 1

                        continue

                    y_vec = np.zeros(

                        shape=(1, len(self.words)),

                        dtype=np.bool

                    )

                    y_vec[0, self.word2numF(y)] = 1.0

                    x_vec = np.zeros(

                        shape=(1, self.config.max_len),

                        dtype=np.int32

                    )

                    for t, char in enumerate(x):

                        x_vec[0, t] = self.word2numF(char)

                    yield x_vec, y_vec

                    i += 1

      (7)train 函数

          def train(self):

                #number_of_epoch = len(self.files_content) // self.config.batch_size

                number_of_epoch = 10

                if not self.model:

                    self.build_model()

                self.model.summary()

                self.model.fit_generator(

                    generator=self.data_generator(),

                    verbose=True,

                    steps_per_epoch=self.config.batch_size,

                    epochs=number_of_epoch,

                    callbacks=[

                        keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),

                        LambdaCallback(on_epoch_end=self.generate_sample_result)

                    ]

                )

    5. 进行模型训练

      model = PoetryModel(Config)

    6. 作诗

          text = input("text:")

        sentence = model.predict(text)

        print(sentence)


    学习资料:

    《中文自然语言处理入门实战》

    相关文章

      网友评论

        本文标题:中文NLP笔记:11. 基于 LSTM 生成古诗

        本文链接:https://www.haomeiwen.com/subject/nntasqtx.html