TensorRTx-YOLOv5工程解读(一)
TensorRTx-YOLOv5工程解读(一)
权重生成:gen_wts.py
作者先是使用了gen_wts.py这个脚本去生成wts文件。顾名思义,这个.wts文件里面存放的就是.pt文件的权重。脚本内容如下:
import sys
import argparse
import os
import struct
import torch
from utils.torch_utils import select_devicedef parse_args():parser = argparse.ArgumentParser(description='Convert .pt file to .wts')parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')args = parser.parse_args()if not os.path.isfile(args.weights):raise SystemExit('Invalid input file')if not args.output:args.output = os.path.splitext(args.weights)[0] + '.wts'elif os.path.isdir(args.output):args.output = os.path.join(args.output,os.path.splitext(os.path.basename(args.weights))[0] + '.wts')return args.weights, args.outputpt_file, wts_file = parse_args()# Initialize
device = select_device('cpu')
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
model.to(device).eval()with open(wts_file, 'w') as f:f.write('{}\n'.format(len(model.state_dict().keys())))for k, v in model.state_dict().items():vr = v.reshape(-1).cpu().numpy()f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f' ,float(vv)).hex())f.write('\n')
第一个函数parse_args()就是正常处理输入的命令行参数,不多做赘述。
主函数内,先是设置设备为CPU,再load进pt文件获得model并转成FP32格式。并设置模型的device和eval模式。
设置完毕后,作者保存权重文件,其中权重文件的内容是作者自定义的。第一行存入的是model的keys的个数,再分别遍历pt文件内的每一个权重,保存为该层名称 该层参数量 16进制权重。
权重读取:common.cpp
首先顺着之前的思路,看看作者是如何load权重的。
// TensorRT weight files have a simple space delimited format:
// [type] [size]
std::map<std::string, Weights> loadWeights(const std::string file) {std::cout << "Loading weights: " << file << std::endl;std::map<std::string, Weights> weightMap;// Open weights filestd::ifstream input(file);assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");// Read number of weight blobsint32_t count;input >> count;assert(count > 0 && "Invalid weight map file.");while (count--){Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size;// Read name and type of blobstd::string name;input >> name >> std::dec >> size;wt.type = DataType::kFLOAT;// Load blobuint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));for (uint32_t x = 0, y = size; x < y; ++x){input >> std::hex >> val[x];}wt.values = val;wt.count = size;weightMap[name] = wt;}return weightMap;
}
此为loadWeight()函数。作者此处使用了std::map容器。map容器在OpenCV和OpenVINO中本身就是大量使用的,所以除了vector之外,也需要掌握map的使用。后面需要往这个型的map中添加权重信息。
同时应该注意,此处的Weights类型在TensorRT的NvInferRuntime.h头文件中有定义:
class Weights
{
public:DataType type; //!< The type of the weights.const void* values; //!< The weight values, in a contiguous array.int64_t count; //!< The number of weights in the array.
};
作者使用了std::ifstream进行输入流变量的定义,并设置了一些变量。代码中的input >> count就是将.wts文件中的第一行的算子数传递给count这个变量,从而构建while循环。
在While循环中,作者先定义了Weights型的wt变量,其类型为DataType::kFLOAT,values直接初始化为nullptr,count初始化一个0在上面即可。
这一句input >> name >> std::dec >> size是将input中的第一部分:权重的名称,赋值给name变量,再将紧跟着name后的size推入给size变量。具体的形式可以参考之前分析gen_wts.py脚本中的权重生成的部分。作者之所以要存入这一算子的权重的size,就是为了方便分配空间大小。声明指针val指向一个大小为sizeof(val) * size的uint32_t的数组,并且将input中这一行的权重全部推入给val这个数组即可。
这一步完成后,设置Weights的values成员为val,count成员为size,并将name作为weightMap的keys,wt作为其values即可。
至此,模型权重加载完毕。
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
