使用Transformer进行机器翻译的构建

1. 加载数据通过torchtext

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from typing import Iterable, List# 源语言是德语
SRC_LANGUAGE = 'de'
# 目标语言是英语
TGT_LANGUAGE = 'en'# 定义token的字典, 定义vocab字典
token_transform = {}
vocab_transform = {}# 创建源语言和目标语言的kokenizer, 确保依赖关系已经安装
# pip install -U spacy
# python -m spacy download en_core_web_sm
# python -m spacy download de_core_news_sm
# get_tokenizer是分词函数, 如果没有特殊的则按照英语的空格分割, 如果有这按照对应的分词库返回. 比如spacy, 返回对应的分词库
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

2. 生成token值的辅助函数

def yield_tokens(data_iter: Iterable, language: str) -> List[str]:# data_iter: 对象的迭代对象 Multi30k对象# language: 对应的翻译语言 {'de': 0, 'en': 1}language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}# 返回对应的数据迭代器对象for data_sample in data_iter:# data_sample:(德文, 英文)# data_sample:('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n', 'Two young, White males are outside near many bushes.\n')# token_transform['de']()=['Zwei', 'junge', 'weiße', 'Männer', 'sind', 'im', 'Freien', 'in', 'der', 'Nähe', 'vieler', 'Büsche', '.', '\n']# or  token_transform['en']分别进行构造对应的字典yield token_transform[language](data_sample[language_index[language]])

3. 利用循环生成对应的数据集的语言句子对

# 定义特殊字符及其对应的索引值
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# 确保标记按其索引的顺序正确插入到词汇表中
special_symbols = ['', '', '', '']for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:# 训练数据集的迭代器,# 数据集是用英文描述图像的英文语句, 然后人工将其翻译为德文的语句,有两个文件, 一个是train.de 一个是train.en文件,# 然后将其构建为(德文, 英文)的形式train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))# 创建torchtext的vocab对象, 即词汇表vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln), # 用于构建 Vocab 的迭代器。必须产生令牌列表或迭代器min_freq=1,#在词汇表中包含一个标记所需的最低频率specials=special_symbols, # 用于添加的特殊字符special_first=True) # 指示是在开头还是结尾插入符号# 将 UNK_IDX 设置为默认索引。未找到令牌时返回此索引
# 如果未设置,则在 Vocabulary 中找不到查询的标记时抛出 RuntimeError
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:vocab_transform[ln].set_default_index(UNK_IDX)

4. 使用Transformer架构构建模型

4.1 构建位置编码器类

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):def __init__(self,emb_size: int,dropout: float, maxlen: int = 5000):'''emb_size: 词嵌入的维度大小dropout: 正则化的大小maxlen: 句子的最大长度'''super(PositionalEncoding, self).__init__()# 将1000的2i/d_model变型为e的指数形式den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)# 效果等价与torch.arange(0, maxlen).unsqueeze(1)pos = torch.arange(0, maxlen).reshape(maxlen, 1)# 构建一个(maxlen, emb_size)大小的全零矩阵pos_embedding = torch.zeros((maxlen, emb_size))# 偶数列是正弦函数填充pos_embedding[:, 0::2] = torch.sin(pos * den)# 奇数列是余弦函数填充


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部