keras的模型处理以及基于h5做推理过程

把用keras自己搭建的模型训练后保存为有图结构的h5模型,然后基于h5去做推理预测全过程详解

我是基于苏神的摘要生成代码,做长难句压缩任务,自己对代码进行了部分修改,然后训练的过程没有切换成tf,所以整个框架是keras的,训练成的模型文件是h5形式的。我的整个过程就是先保存为带有图结构的h5模型,然后根据h5进行了一遍推理,然后将h5转为pb,然后根据pb进行了一遍推理。这里就先讲一下依据h5进行推理的过程。

第一步:保存模型文件为.h5
保存模型文件有三种形式,一种是只保存图结构(训练前就保存),一种是保留图结构和参数,一种是只保留参数,对应的函数分别是:

model.save('m1.h5')
model.save('m2.h5')
model.save_weights('m3.h5')

大小方面:m1最小,其次m3,m2最大。

对应的模型加载方式也是有区别的,如果全部是keras自带的那些API搭建的模型那么m1,m2的加载采用以下方式:

from keras.models import load_model
model = load_model('m1.h5')
model = load_model('m2.h5')

但是像我的模型里还有一些别的自定义测层,所以需要将这些特殊的自定义的层进行声明,并创建实例,如下:

custom_objects = {'Embedding': Embedding,'BiasAdd': BiasAdd,'MultiHeadAttention': MultiHeadAttention,'LayerNormalization': LayerNormalization,'PositionEmbedding': PositionEmbedding,'FeedForward': FeedForward,'Loss': Loss,'CrossEntropy': CrossEntropy,'Unilm': Unilm,'gelu_erf': gelu_erf,'gelu_tanh': gelu_tanh,'gelu': gelu_erf,}
model = load_model("model_weight.h5",custom_objects)

前提先找到自己的这些自定义层在哪儿,然后把这些层给import进去,就能正常加载了。
然后进行推理,因为我做的文本摘要生成,所以主要代码如下:

vocab_path='./pre_model/chinese_wobert_L-12_H-768_A-12/vocab.txt'class Decoder(AutoRegressiveDecoder):"""seq2seq解码器"""@AutoRegressiveDecoder.set_rtype('probas')def predict(self, inputs, output_ids, step):#pdb.set_trace()token_ids, segment_ids = inputstoken_ids = np.concatenate([token_ids, output_ids], 1)segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)logits = self.model.predict([token_ids, segment_ids])[:, -1]if self.second_model:second_logits = self.second_model.predict([token_ids, segment_ids])[:, -1]return (logits + second_logits)/2return logitsdef generate(self, text, tokenizer, topk=2):max_c_len = self.encode_max_len - self.maxlen#pdb.set_trace()token_ids, segment_ids = tokenizer.encode(text, max_length=max_c_len)output_ids = self.beam_search([token_ids, segment_ids], topk)  # 基于beam searchreturn tokenizer.decode(output_ids)model = load_model('./checkpoint/model/best_model.h5',custom_objects)
token_dict, keep_tokens = load_vocab(dict_path=vocab_path,simplified=True,startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
Summary = Decoder(start_id=None, end_id=tokenizer._token_end_id,encode_max_len=256,decode_max_len=32 , model=model)
test_sentence = input('enter sentence:')
generate = Summary.generate(test_sentence.strip(), tokenizer, topk=3).strip()
print(generate)

最后运行这个文件即可。


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部