Chinese Text Classification with fastText

Step 1: Install fastText

Reference: the official installation guide.

On Linux:
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .

After a successful install, verify the import. Create a test.py file containing:

import fastText.FastText as fasttext   # the IDE may flag this line in red

Save, exit, and run:

python3 test.py

If no error is reported, the installation succeeded.
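
Note: depending on the version you end up with, the Python bindings may also be available directly from PyPI, where the module name is lowercase. A minimal sketch, assuming the PyPI fasttext package (its import name differs from the source build above):

$ pip install fasttext

# test.py, for the PyPI flavor of the bindings
import fasttext
print(fasttext.__name__)   # prints 'fasttext' if the module resolved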

Step 2: Prepare the dataset

  • I use the Tsinghua news dataset (THUCNews); since the full dataset is large, only a subset is used here.
  • Data link: get the data from the netdisk link, extraction code: b3vd (data.txt is the dataset, stopwords.txt is the stop-word list).
  • The downloaded data looks like this:
    [screenshot omitted: sample rows from data.txt]
  • The corresponding labels are (since only a small part of the data is used, data.txt contains only a subset of these):

mapper_tag = {
    '财经': 'Finance',
    '彩票': 'Lottery',
    '房产': 'Property',
    '股票': 'Shares',
    '家居': 'Furnishing',
    '教育': 'Education',
    '科技': 'Technology',
    '社会': 'Sociology',
    '时尚': 'Fashion',
    '时政': 'Affairs',
    '体育': 'Sports',
    '星座': 'Constellation',
    '游戏': 'Game',
    '娱乐': 'Entertainment'
}
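
The conversion step in Step 3 splits each line of data.txt on the first comma and strips a __label__ prefix, so the file is assumed to contain lines of the form __label__<tag> , <pre-segmented text>. A quick way to confirm this assumption before going further:

# Peek at the first few lines of data.txt to check the expected format
with open('data.txt', encoding='utf-8') as f:
    for _, line in zip(range(3), f):
        print(line.strip()[:80])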

Step 3: Preprocess the data

  • Since data.txt has already been segmented and had its stop words removed, the only remaining step is to split it into a training set and a test set.
  • Utility code for segmentation and stop-word removal (you do not need to run this part for the pipeline below):
import re
from types import MethodType, FunctionType

import jieba


def clean_txt(raw):
    # Keep only digits, ASCII letters, and CJK characters
    fil = re.compile(r"[^0-9a-zA-Z\u4e00-\u9fa5]+")
    return fil.sub(' ', raw)


def seg(sentence, sw, apply=None):
    # Optionally clean the sentence first, then segment it with jieba,
    # dropping empty tokens and stop words
    if isinstance(apply, FunctionType) or isinstance(apply, MethodType):
        sentence = apply(sentence)
    return ' '.join([i for i in jieba.cut(sentence) if i.strip() and i not in sw])


def stop_words():
    # Note: the netdisk download names this file stopwords.txt;
    # rename it (or adjust the path) to match.
    with open('stop_words.txt', 'r', encoding='utf-8') as swf:
        return [line.strip() for line in swf]


# Process a single sentence:
content = '上海天然橡胶期价周三再创年内新高,主力合约突破21000元/吨重要关口。'
res = seg(content.lower().replace('\n', ''), stop_words(), apply=clean_txt)
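
For orientation, printing res should give a space-joined token string; the tokens below are only illustrative, since the exact output depends on the jieba version and the stop-word list:

print(res)
# e.g. 上海 天然橡胶 期价 周三 再创 年内 新高 主力 合约 突破 21000 元 吨 重要 关口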
  • Splitting the data (I first convert the txt file to csv, which makes the later computation easier):
from random import shuffle

import pandas as pd


class _MD(object):
    mapper = {
        str: '',
        int: 0,
        list: list,
        dict: dict,
        set: set,
        bool: False,
        float: .0
    }

    def __init__(self, obj, default=None):
        self.dict = {}
        assert obj in self.mapper, 'got an unexpected type'
        self.t = obj
        if default is None:
            return
        assert isinstance(default, obj), f'default ({default}) must be {obj}'
        self.v = default

    def __setitem__(self, key, value):
        self.dict[key] = value

    def __getitem__(self, item):
        if item not in self.dict and hasattr(self, 'v'):
            self.dict[item] = self.v
            return self.v
        elif item not in self.dict:
            if callable(self.mapper[self.t]):
                self.dict[item] = self.mapper[self.t]()
            else:
                self.dict[item] = self.mapper[self.t]
            return self.dict[item]
        return self.dict[item]


def defaultdict(obj, default=None):
    return _MD(obj, default)


class TransformData(object):
    def to_csv(self, handler, output, index=False):
        # Group contents by label, then write one csv column per label
        dd = defaultdict(list)
        for line in handler:
            label, content = line.split(',', 1)
            dd[label.strip('__label__').strip()].append(content.strip())
        df = pd.DataFrame()
        for key in dd.dict:
            col = pd.Series(dd[key], name=key)
            df = pd.concat([df, col], axis=1)
        return df.to_csv(output, index=index, encoding='utf-8')


def split_train_test(source, auth_data=False):
    if not auth_data:
        train_proportion = 0.8
    else:
        train_proportion = 0.98

    basename = source.rsplit('.', 1)[0]
    train_file = basename + '_train.txt'
    test_file = basename + '_test.txt'

    handle = pd.read_csv(source, index_col=False, low_memory=False)
    train_data_set = []
    test_data_set = []
    for head in list(handle.head()):   # iterating the frame yields its column names
        train_num = int(handle[head].dropna().__len__() * train_proportion)
        sub_list = [f'__label__{head} , {item.strip()}\n' for item in handle[head].dropna().tolist()]
        train_data_set.extend(sub_list[:train_num])
        test_data_set.extend(sub_list[train_num:])
    shuffle(train_data_set)
    shuffle(test_data_set)

    with open(train_file, 'w', encoding='utf-8') as trainf, \
         open(test_file, 'w', encoding='utf-8') as testf:
        for tds in train_data_set:
            trainf.write(tds)
        for i in test_data_set:
            testf.write(i)

    return train_file, test_file


# Convert the txt file to csv
td = TransformData()
handler = open('data.txt')
td.to_csv(handler, 'data.csv')
handler.close()

# Split the csv file; this generates two files (data_train.txt and data_test.txt)
train_file, test_file = split_train_test('data.csv', auth_data=True)
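
A side note on the custom defaultdict defined above: unlike collections.defaultdict, it accepts an explicit starting value, which the per-label evaluation in Step 4 relies on (counts start at 1, not 0). A quick illustration:

counts = defaultdict(int, 1)   # every unseen key starts at 1
counts['Sports'] += 1
print(counts['Sports'])        # 2
print(counts['Game'])          # 1 (the default)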

Step 4: Train the model

import os

import numpy as np
import fastText.FastText as fasttext


def train_model(ipt=None, opt=None, model='', dim=100, epoch=5, lr=0.1, loss='softmax'):
    np.set_printoptions(suppress=True)
    if os.path.isfile(model):
        # Reuse a previously trained model if one exists on disk
        classifier = fasttext.load_model(model)
    else:
        classifier = fasttext.train_supervised(ipt, label='__label__', dim=dim, epoch=epoch,
                                               lr=lr, wordNgrams=2, loss=loss)
        """
        Trains a supervised model and returns a model object.
        @param input:             path to the training data file
        @param lr:                learning rate
        @param dim:               size of word vectors
        @param ws:                used with the cbow model
        @param epoch:             number of epochs
        @param minCount:          word-frequency threshold; rarer words are filtered out at init
        @param minCountLabel:     label threshold; rarer labels are filtered out at init
        @param minn:              min char count when building subwords
        @param maxn:              max char count when building subwords
        @param neg:               number of negatives sampled
        @param wordNgrams:        word n-gram length
        @param loss:              loss function: softmax, ns (negative sampling), hs (hierarchical softmax)
        @param bucket:            bucket size [A, B]: A for word vectors in the corpus, B for those not in it
        @param thread:            number of threads; each handles a slice of the input, thread 0 reports the loss
        @param lrUpdateRate:      learning-rate update rate
        @param t:                 sampling threshold
        @param label:             label prefix
        @param verbose:           verbosity level
        @param pretrainedVectors: path to pretrained word vectors; words found there are not randomly initialized
        @return a model object
        """
        classifier.save_model(opt)
    return classifier


dim = 100
lr = 5      # only used in the model file name; the learning rate passed below is 0.5
epoch = 5
model = f'data_dim{str(dim)}_lr0{str(lr)}_iter{str(epoch)}.model'

classifier = train_model(ipt='data_train.txt',
                         opt=model,
                         model=model,
                         dim=dim, epoch=epoch, lr=0.5)

result = classifier.test('data_test.txt')
print(result)

# The overall result is (number of test samples, precision, recall):
(9885, 0.9740010116337886, 0.9740010116337886)
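
Before breaking the numbers down per label, it helps to see what a single prediction looks like. A minimal sketch; the sentence is a made-up, already-segmented example, and with these bindings predict on a single string should return a tuple of labels plus an array of probabilities:

labels, probs = classifier.predict('主力 合约 突破 21000 元 吨 重要 关口')
print(labels, probs)   # e.g. ('__label__Shares',) with its probability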
  • The score looks remarkably high. Since the test above evaluates the whole set at once and fastText only reports aggregate numbers, precision and recall come out identical. Below we compute precision, recall, and F1 for each label individually.
def cal_precision_and_recall(file='data_test.txt'):
    # Uses the custom defaultdict from Step 3, so every count starts at 1
    precision = defaultdict(int, 1)
    recall = defaultdict(int, 1)
    total = defaultdict(int, 1)

    with open(file) as f:
        for line in f:
            label, content = line.split(',', 1)
            total[label.strip().strip('__label__')] += 1
            # predict() on a list returns ([labels...], [probabilities...])
            labels2 = classifier.predict([seg(sentence=content.strip(), sw='', apply=clean_txt)])
            pre_label, sim = labels2[0][0][0], labels2[1][0][0]
            recall[pre_label.strip().strip('__label__')] += 1

            if label.strip() == pre_label.strip():
                precision[label.strip().strip('__label__')] += 1

    print('precision', precision.dict)
    print('recall', recall.dict)
    print('total', total.dict)

    for sub in precision.dict:
        pre = precision[sub] / total[sub]
        rec = precision[sub] / recall[sub]
        F1 = (2 * pre * rec) / (pre + rec)
        print(f"{sub.strip('__label__')}  precision: {str(pre)}  recall: {str(rec)}  F1: {str(F1)}")
  • Results (note: 'Financ' and 'Gam' appear truncated because str.strip('__label__') strips any of the characters _, l, a, b, e from both ends of the string, so the trailing 'e' of 'Finance' and 'Game' is removed too):
precision {'Technology': 983, 'Education': 972, 'Shares': 988, 'Affairs': 975, 'Entertainment': 991, 'Financ': 982, 'Furnishing': 975, 'Gam': 841, 'Sociology': 946, 'Sports': 978}
recall {'Technology': 992, 'Education': 1013, 'Shares': 1007, 'Affairs': 995, 'Entertainment': 1022, 'Financ': 1001, 'Furnishing': 997, 'Gam': 854, 'Sociology': 1025, 'Sports': 989}
total {'Technology': 1001, 'Education': 1001, 'Shares': 1001, 'Affairs': 1001, 'Entertainment': 1001, 'Financ': 1001, 'Furnishing': 1001, 'Gam': 876, 'Sociology': 1001, 'Sports': 1001, 'Property': 11}

Technology  precision: 0.9820179820179821  recall: 0.9909274193548387  F1: 0.9864525840441545
Education  precision: 0.971028971028971  recall: 0.9595261599210266  F1: 0.9652432969215492
Shares  precision: 0.987012987012987  recall: 0.9811320754716981  F1: 0.9840637450199202
Affairs  precision: 0.974025974025974  recall: 0.9798994974874372  F1: 0.9769539078156312
Entertainment  precision: 0.99000999000999  recall: 0.9696673189823874  F1: 0.9797330696984675
Financ  precision: 0.981018981018981  recall: 0.981018981018981  F1: 0.981018981018981
Furnishing  precision: 0.974025974025974  recall: 0.9779338014042126  F1: 0.975975975975976
Gam  precision: 0.9600456621004566  recall: 0.9847775175644028  F1: 0.9722543352601155
Sociology  precision: 0.945054945054945  recall: 0.9229268292682927  F1: 0.9338598223099703
Sports  precision: 0.977022977022977  recall: 0.9888776541961577  F1: 0.9829145728643216
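
If a single summary number is wanted, a macro-average over the ten per-label F1 values printed above is easy to compute (Property is missing from the breakdown because it never appears in the precision dict):

f1_scores = [0.9864525840441545, 0.9652432969215492, 0.9840637450199202,
             0.9769539078156312, 0.9797330696984675, 0.981018981018981,
             0.975975975975976, 0.9722543352601155, 0.9338598223099703,
             0.9829145728643216]
print(sum(f1_scores) / len(f1_scores))   # ≈ 0.9738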

As you can see, the results are very respectable; fastText is impressively powerful.

The consolidated code:

def main(source):
    basename = source.rsplit('.', 1)[0]
    csv_file = basename + '.csv'

    # txt -> csv
    td = TransformData()
    handler = open(source)
    td.to_csv(handler, csv_file)
    handler.close()

    # Split into train / test files
    train_file, test_file = split_train_test(csv_file)

    dim = 100
    lr = 5
    epoch = 5
    model = f'data/data_dim{str(dim)}_lr0{str(lr)}_iter{str(epoch)}.model'

    classifier = train_model(ipt=train_file,
                             opt=model,
                             model=model,
                             dim=dim, epoch=epoch, lr=0.5)

    result = classifier.test(test_file)
    print(result)

    cal_precision_and_recall(test_file)


if __name__ == '__main__':
    main('data.txt')
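
Once main() has run, the saved model can be reloaded later for inference without retraining. A minimal sketch using the path produced above; note the input must be pre-segmented, just like the training data:

import fastText.FastText as fasttext

classifier = fasttext.load_model('data/data_dim100_lr05_iter5.model')
print(classifier.predict('上海 天然橡胶 期价 周三 再创 年内 新高'))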

