BERT中文任务实战(文本分类、预测下一句)踩坑记录

文章目录

  • 一、概述
  • 二、Bert简介
        • 2.1 简要说明
        • 2.2 fine-tune原理
  • 三、在项目数据集上fine-tune教程
        • 3.1整体流程
        • 3.2 自定义DataProcessor
        • 3.3 参数设置
        • 3.4 预测函数
  • 四、踩坑记录
  • 五、参考文献

一、概述

最近参与了两个项目开发,主要内容是基于NLP和深度学习的文本处理任务。其中有两个子模块用到文本分类和预测下一句模型,刚好前段时间自己折腾学习了一点BERT,就打算实践一下,记录下遇到的问题和解决方案。

二、Bert简介

2.1 简要说明

关于BERT的介绍,已经有铺天盖地的文章和博客或深或浅的介绍了,这里暂时先不展开详细介绍,后面会写一篇论文翻译和理解的博客,这里暂时简要带过,这里我们只简单说明下为什么我们选择在自己数据集上fine-tuning,然后用于预测。

2.2 fine-tune原理

在BERT论文中,作者说明了BERT的fine-tune原理。

BERT模型首先会对input进行编码,转为模型需要的编码格式,使用辅助标记符[CLS]和[SEP]来表示句子的开始和分隔。然后根据输入得到对应的embedding,这里的embedding是三种embedding的和,分别是token、segment、position级别。
input representation

得到整体的embedding后即使用相关模型进行学习,最终根据不同任务的分类层得到结果。
在这里插入图片描述

图a表示句子对的分类任务,例如预测下一句、语义相似度等任务,输入是两个句子A和B,中间用[SEP]分隔,最终得到的class label就表示是否下一句或者是否是语义相似的。

图b表示单句分类任务,如常见的文本分类、情感分析等。输入就是一个单独的句子,最终的class label就是表示句子属于哪一类,或者属于什么情感。

图c表示问答任务,主要用于SQuAD数据集,输入是一个问题和问题对应的段落,用[SEP]分隔,这里输出的结果就不是某个class label而是答案在给定段落的开始和终止位置,主要用于阅读理解任务。

图d表示单个句子标注任务,例如常见的命名实体识别任务,输入就是一个单独的句子,输出是句子中每个token对应的类别标注。

这里我主要是使用了前两个任务的方法,分别fine-tune了预测下一句模型和单句分类模型。

三、在项目数据集上fine-tune教程

因为之前主要使用的框架是Pytorch,因此fine-tune的代码也主要参考了pytorch版的bert复现代码。

先做的是文本分类任务,主要参考了examples/run_classifier.py,可以看到,整个脚本代码有接近1000行,但是其实在我们的fine-tune过程中,需要理解的关键部分主要是DataProcessor类。我们的fine-tune说的直白一点,就是把我们自己的数据整理好,转换成BERT模型能够读取的输入,只要模型读到了inputs,后续的各种内部转换表示其实已经不需要我们关注了(但是最好还是要理解,学习最牛的模型的原理和思路)。

3.1整体流程

首先简要介绍下fine-tune的整体流程,如下图所示:

在这里插入图片描述

我们首先需要先进行一些预处理,就是把训练集、验证集、测试集标签化。接着会调用我们自定义的继承DataProcessor类的MyPro类,这一步就是实现将我们的训练数据,转换成模型能够获取的标准输入格式。这里是转换成论文定义的一种InputExamples格式,相关代码:

class InputExample(object):"""A single training/test example for simple sequence classification."""def __init__(self, guid, text_a, text_b=None, label=None):"""生成一个InputExample.Args:guid: 每个example的独有id.text_a: 字符串,也就是输入的未分割的句子A,对于单句分类任务来说,text_a是必须有的text_b: 可选的输入字符串,单句分类任务不需要这一项,在预测下一句或者阅读理解任务中需要输入text_b,text_a和text_b中间使用[SEP]分隔label: 也是可选的字符串,就是对应的文本或句子的标签,在训练和验证的时候需要指定,但是在测试的时候可以不选"""self.guid = guidself.text_a = text_aself.text_b = text_bself.label = label

接着会调用convert_examples_to_features()将所有的InputExamples转为一种train_features格式,相关代码:

class InputFeatures(object):"""A single set of features of data.Args:input_ids:  token的id,在chinese模式中就是每个分词的id,对应一个word vector,就是之前提到的混合embeddinginput_mask:  真实字符对应1,补全字符对应0,在padding的时候可能会补0,需要记录一下真实的输入字符,模型的attention机制只关注这些字符segment_ids:  句子标识符,第一句全为0,第二句全为1,主要是用于区分单句任务或者是句子对任务,但其实我们通过使用[SEP]已经起到了区分作用,这里主要还是为了便于模型识别计算labe


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部