[Hands-on PyTorch Notes] 37.4 The Dataset for Fine-Tuning BERT

The Dataset for Fine-Tuning BERT

The natural language inference task:

It studies whether a hypothesis can be inferred from a premise, where both are text sequences. In other words, natural language inference determines the logical relationship between a pair of text sequences. Such relationships usually fall into three types:

  • Entailment: the hypothesis can be inferred from the premise.
  • Contradiction: the negation of the hypothesis can be inferred from the premise.
  • Neutral: all other cases.

The Stanford Natural Language Inference (SNLI) Dataset

A collection of more than 500,000 labeled English sentence pairs.

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

#@save
d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = r"D:\environment\data\data\snli_1.0"  # raw string avoids escape issues

Reading the dataset

#@save
def read_snli(data_dir, is_train):
    """Parse the SNLI dataset into premises, hypotheses, and labels"""
    def extract_text(s):
        # Remove information that we will not use
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # Substitute two or more consecutive whitespace with a single space
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, encoding='utf-8') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels
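To see what extract_text strips out, here is a minimal standalone sketch; the sample string imitates the binary-parse columns of the SNLI .txt files and is not taken verbatim from the data:

import re

s = '( ( A person ) ( is ( outdoors , ( on ( a horse ) ) ) ) )'
s = re.sub('\\(', '', s)       # drop opening parentheses from the parse tree
s = re.sub('\\)', '', s)       # drop closing parentheses
s = re.sub('\\s{2,}', ' ', s)  # collapse runs of whitespace into one space
print(s.strip())               # -> A person is outdoors , on a horse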
train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('Premise:', x0)
    print('Hypothesis:', x1)
    print('Label:', y)
Premise: A person on a horse jumps over a broken down airplane .
Hypothesis: A person is training his horse for a competition .
Label: 2
Premise: A person on a horse jumps over a broken down airplane .
Hypothesis: A person is at a diner , ordering an omelette .
Label: 1
Premise: A person on a horse jumps over a broken down airplane .
Hypothesis: A person is outdoors , on a horse .
Label: 0

Counting the three relationship classes

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

The first row is the training set and the second the test set; the three classes are fairly balanced.

data[2] holds the labels, so counting the label occurrences is all we need.
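An equivalent way to get these counts is collections.Counter from the standard library; a small sketch (the inner list comprehension above is redundant, since data[2].count(i) would work directly):

from collections import Counter

for data in [train_data, test_data]:
    # Counter maps each label id (0/1/2) to its frequency
    print(Counter(data[2]))
# e.g. Counter({0: 183416, 1: 183187, 2: 182764})
#      Counter({0: 3368, 1: 3237, 2: 3219})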

#@save
class SNLIDataset(torch.utils.data.Dataset):
    """A customized dataset for loading the SNLI dataset"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + all_hypothesis_tokens,
                                   min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
            for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)
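_pad relies on d2l.truncate_pad, which cuts a sequence of token indices down to num_steps or pads it out with the <pad> index; a quick illustration with made-up indices:

from d2l import torch as d2l

print(d2l.truncate_pad([3, 8, 12], 5, 0))           # pad:      [3, 8, 12, 0, 0]
print(d2l.truncate_pad([3, 8, 12, 7, 9, 4], 5, 0))  # truncate: [3, 8, 12, 7, 9]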

The vocab must be consistent with the one used when BERT was pretrained; otherwise the model will not recognize the tokens. That is why, when downloading a pretrained model, you normally download both the model and its vocab.
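For reference, the later d2l sections load a pretrained BERT vocabulary roughly like this; a sketch assuming the d2l convention of a vocab.json file next to the model weights, with a hypothetical local path:

import json
import os
from d2l import torch as d2l

# Hypothetical directory holding a downloaded pretrained BERT (with vocab.json)
pretrained_dir = r"D:\environment\data\bert.small.torch"

vocab = d2l.Vocab()  # start from an empty vocabulary...
# ...then overwrite it with the pretrained token list so ids match pretraining
vocab.idx_to_token = json.load(open(os.path.join(pretrained_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx
                      for idx, token in enumerate(vocab.idx_to_token)}
# Pass this vocab to SNLIDataset(dataset, num_steps, vocab) instead of rebuilding one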

Putting it all together

#@save
def load_data_snli(batch_size, num_steps=50):
    """Load the SNLI dataset and return data iterators and the vocabulary"""
    # Unused here: multi-process loading errors out on Windows (see pitfalls below)
    num_workers = d2l.get_dataloader_workers()
    data_dir = r"D:\environment\data\data\snli_1.0"
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False)
    return train_iter, test_iter, train_set.vocab

Checking the sizes

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples
18678
for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break
torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

The batch size is 128.

Each sentence has length 50 (num_steps).
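To sanity-check the truncation and padding, one padded premise can be decoded back into tokens with vocab.to_tokens; a minimal sketch (short sentences should end in a run of <pad> tokens):

for X, Y in train_iter:
    # X[0] holds the padded premises; decode the first one in the batch
    print(' '.join(vocab.to_tokens(X[0][0].tolist())))
    break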

Finally, the pitfalls I hit in this section:

  • Same as the data loading in the previous section:

    torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                num_workers=num_workers)

    Passing num_workers=num_workers to read data with multiple worker processes raises an error (a common workaround is sketched after this list).

  • An error when downloading and extracting the dataset:

    OSError: [Errno 22] Invalid argument: '..\\data\\snli_1.0\\Icon\r

    Cause: the archive "snli_1.0.zip" contains two entries whose paths are "snli_1.0\Icon\r" and "__MACOSX/snli_1.0/._Icon\r"; these paths cannot be parsed, so the whole archive fails to extract.

    Fix: extract the archive manually and set data_dir to the path of the extracted dataset, i.e. change

    data_dir = d2l.download_extract('SNLI')

    to

    data_dir = r"D:\environment\data\data\snli_1.0"
    

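For the first pitfall: on Windows, DataLoader workers are started with the "spawn" method, which re-imports the main script, so multi-process loading typically fails unless the entry point is guarded. A minimal sketch of the usual workaround if you do want num_workers > 0 (otherwise just leave it at 0, as done above):

if __name__ == '__main__':
    # Worker processes re-import this module on Windows ("spawn"),
    # so all data-loading code must run under this guard
    train_iter, test_iter, vocab = load_data_snli(128, 50)
    for X, Y in train_iter:
        print(X[0].shape, X[1].shape, Y.shape)
        break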
