美文网首页
文本分类LSTM简单代码

文本分类LSTM简单代码

作者: AntiGravity | 来源:发表于2022-03-18 20:37 被阅读0次
#-*-coding:utf-8-*-
import jieba
import pandas as pd
import re

def _read_headerless_csv(path, column_names):
    """Read a UTF-8 CSV that has no header row and name its columns."""
    frame = pd.read_csv(path, encoding='utf8', header=None)
    frame.columns = column_names
    return frame


# Training file: one (label, text) pair per row.
train = _read_headerless_csv('data/training.csv', ['label', 'data'])
train_seged = train.copy()

# Test file: one (id, text) pair per row.
test = _read_headerless_csv('data/testing.csv', ['id', 'data'])
test_seged = test.copy()
# Load the stop-word list
def getStopwords():
    """Return the lines of ``data/stopwords.txt`` as a list of strings.

    Every line is stripped of surrounding whitespace. File order is kept,
    and a blank line comes back as an empty string.
    """
    with open("data/stopwords.txt", encoding='utf8') as handle:
        return [entry.strip() for entry in handle]
stopwords = getStopwords()
# Set view of the stop words. seg() tests every token for membership, and a
# list lookup would rescan the whole stop-word list for each token.
# NOTE: built once here; it does not follow later mutations of `stopwords`.
_STOPWORD_SET = frozenset(stopwords)

# Printable ASCII (0x21-0x7e) plus letters and digits. For str patterns `\d`
# also matches non-ASCII decimal digits, so the second alternative is not
# redundant with the first.
_ASCII_OR_ALNUM = re.compile(r'[\x21-\x7e]|[a-zA-Z\d]')
# Runs of word characters; everything else acts as a separator.
_WORD_RUN = re.compile(r'\w+')


def seg(text):
    """Segment ``text`` with jieba and return the kept words joined by spaces.

    Steps, in order:
      1. tokenize with ``jieba.cut``;
      2. drop tokens that appear in the stop-word list;
      3. join the tokens with spaces and delete printable ASCII characters,
         letters and digits;
      4. collect the remaining runs of word characters, which discards the
         punctuation that step 3 did not cover.

    Args:
        text: the raw text to process.

    Returns:
        The surviving words separated by single spaces. An empty string is
        returned when nothing survives, or when ``text`` is not a ``str``
        (for example the NaN pandas produces for an empty CSV field).
    """
    if not isinstance(text, str):
        # jieba.cut raises on non-string input; treat a missing cell as empty.
        return ''
    tokens = [word for word in jieba.cut(text) if word not in _STOPWORD_SET]
    stripped = _ASCII_OR_ALNUM.sub('', ' '.join(tokens))
    return ' '.join(_WORD_RUN.findall(stripped))

import os

# Replace each raw text with its segmented, filtered form (in place on the frames).
train['data'] = train['data'].map(seg)
test['data'] = test['data'].map(seg)

# to_csv does not create missing directories; make sure the target exists so
# the segmentation work above is not lost to an OSError at write time.
os.makedirs('output', exist_ok=True)
train.to_csv('output/train.csv', index=False, encoding='utf-8')
test.to_csv('output/test.csv', index=False, encoding='utf-8')

import tensorflow as tf
from sklearn.metrics import accuracy_score

NUM_CLASSES = 11     # width of the softmax output layer
EMBEDDING_DIM = 64   # size of each token embedding
LSTM_UNITS = 64      # size of the LSTM hidden state
EPOCHS = 15          # passes over the training data

tlist = train['data'].tolist()
# Shift labels down by one so they can be used as class indices.
# NOTE(review): presumably the labels in the file run 1..11 -- confirm.
tlabel = [i - 1 for i in train['label'].tolist()]

# Learn the vocabulary from the segmented training texts.
tv = tf.keras.layers.TextVectorization()
tv.adapt(tlist)

model = tf.keras.Sequential([
    tv,
    tf.keras.layers.Embedding(len(tv.get_vocabulary()), EMBEDDING_DIM),
    tf.keras.layers.LSTM(LSTM_UNITS),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(),
)

# One fit call with `epochs` instead of a Python loop of single-epoch fits.
model.fit(tlist, tlabel, epochs=EPOCHS)

# Accuracy on the training data itself (not a held-out set). The value is
# printed so it is visible when the file runs as a script, not only as the
# last expression of a notebook cell.
preds = model.predict(tlist)
train_accuracy = accuracy_score(tlabel, preds.argmax(-1))
print('train accuracy:', train_accuracy)

相关文章

网友评论

      本文标题:文本分类LSTM简单代码

      本文链接:https://www.haomeiwen.com/subject/eujwdrtx.html