Paper Practice - Multi-Context Attention for Human Pose Estimation
Similar to Paper Practice - Stacked Hourglass Networks for Human Pose Estimation, this post estimates human body keypoints based on Docker-Torch.
Here we only run a simple test of the estimation results; due to limited GPU memory, not all scales in scale_search could be included.
- [Torch-Code]
- [Pre-trained model]
1. Human pose estimation on images - demo.lua
# There are two input arguments; the second defaults to 'mean'
th demo.lua imglist.txt 'max'
# or
th demo.lua imglist.txt
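The second argument only controls how the heatmaps from the different scales are fused before taking the joint maxima. As a rough sketch of the two fusion modes in NumPy (the shapes below are hypothetical, not the script's actual tensors):

```python
import numpy as np

# Hypothetical heatmap pyramid: 2 scales x 16 joints x 64 x 64
np.random.seed(0)
hmpyra = np.random.rand(2, 16, 64, 64)

def fuse(hmpyra, mode='mean'):
    # 'max' keeps the strongest response across scales; 'mean' averages them
    if mode == 'max':
        return hmpyra.max(axis=0)
    return hmpyra.mean(axis=0)

fused_max = fuse(hmpyra, 'max')
fused_mean = fuse(hmpyra)  # default mode, like the demo script
```

Since the pointwise maximum of a set of values is never smaller than their mean, `'max'` fusion tends to keep sharper peaks, while `'mean'` suppresses responses that only fire at one scale.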
require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')

--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
a = loadImageNames(arg[1]) -- Batch-load the list of image filenames
m = torch.load('../checkpoints/mpii/crf_parts/model.t7') -- Load pre-trained model
m:cuda()
m:evaluate()

-- Parameters
local isflip = true
local minusmean = true
local scale_search = {1.0, 1.1} -- Choose according to available GPU memory
-- local scale_search = {0.7,0.8,0.9,1.0,1.1,1.2} -- used in paper with NVIDIA Titan X (12 GB memory).

-- Displays a convenient progress bar
idxs = torch.range(1, a.nsamples)
nsamples = idxs:nElement()
xlua.progress(0, nsamples)
preds = torch.Tensor(nsamples,16,3)
imgs = torch.Tensor(nsamples,3,256,256)
local imgpath = '../data/image/'
--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
for idx = 1,nsamples do
    -- Set up input image
    local imgname = paths.concat(imgpath, a['images'][idxs[idx]])
    print(imgname)
    local im = image.load(imgname)
    local original_scale = 256/200 -- Assumes the person has been cropped and resized to 256 beforehand
    local center = {128.0, 128.0}
    local fuseInp = torch.zeros(#scale_search, 3, 256, 256)
    local hmpyra = torch.zeros(#scale_search, 16, im:size(2), im:size(3))
    local batch = torch.Tensor(#scale_search, 3, 256, 256)
    local flipbatch = torch.Tensor(#scale_search, 3, 256, 256)
    for is, factor in ipairs(scale_search) do
        local scale = original_scale * factor
        local inp = crop(im, center, scale, 0, 256)
        batch[{is, {}, {}, {}}]:copy(inp)
        imgs[idx]:copy(inp)
    end
    -- Subtract mean
    if minusmean then
        batch:add(-0.5)
    end
    -- Get network output
    local out = m:forward(batch:cuda())
    -- Get flipped output
    if isflip then
        out = applyFn(function (x) return x:clone() end, out)
        local flippedOut = m:forward(flip(batch):cuda())
        flippedOut = applyFn(function (x) return flip(shuffleLR(x)) end, flippedOut)
        out = applyFn(function (x,y) return x:add(y):div(2) end, out, flippedOut)
    end
    cutorch.synchronize()
    local hm = out[#out]:float()
    hm[hm:lt(0)] = 0
    -- Get heatmaps (original image size)
    for is, scale in pairs(scale_search) do
        local hm_img = getHeatmaps(im, center, original_scale*scale, 0, 256, hm[is])
        hmpyra[{is, {}, {}, {}}]:copy(hm_img:sub(1, 16))
    end
    -- Fuse heatmaps
    if arg[2] == 'max' then
        fuseHm = hmpyra:max(1)
    else
        fuseHm = hmpyra:mean(1)
    end
    fuseHm = fuseHm[1]
    fuseHm[fuseHm:lt(0)] = 0
    -- Get predictions
    for p = 1,16 do
        local maxy, iy = fuseHm[p]:max(2)
        local maxv, ix = maxy:max(1)
        ix = torch.squeeze(ix)
        preds[idx][p][2] = ix
        preds[idx][p][1] = iy[ix]
        preds[idx][p][3] = maxy[ix]
    end
    xlua.progress(idx, nsamples)
    collectgarbage()
end

-- Save predictions
local predFile = hdf5.open('../preds/preds.h5', 'w')
predFile:write('preds', preds)
predFile:write('imgs', imgs)
predFile:close()
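The per-joint prediction step at the end of the loop is just a 2-D argmax over each fused heatmap, recording the peak location and its score. A minimal NumPy equivalent (a hypothetical standalone helper, using 0-based indices unlike Lua's 1-based ones):

```python
import numpy as np

def get_preds(fuseHm):
    """For each joint heatmap, return (x, y, score) of its peak (0-based)."""
    preds = np.zeros((fuseHm.shape[0], 3))
    for p in range(fuseHm.shape[0]):
        # Flatten, find the peak, then recover (row, col) coordinates
        flat_idx = fuseHm[p].argmax()
        y, x = np.unravel_index(flat_idx, fuseHm[p].shape)
        preds[p] = [x, y, fuseHm[p][y, x]]
    return preds

# Toy check: 16 joint heatmaps, joint 0 peaks at row 10, col 20
hm = np.zeros((16, 64, 64))
hm[0, 10, 20] = 1.0
preds = get_preds(hm)
```

Like the Lua code, the x coordinate (column) is stored first and y (row) second, which matches how show.py later reads `pose[idx][0]` and `pose[idx][1]`.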
2. Pose estimation visualization - show.py
#!/usr/bin/env python
import h5py
import scipy.misc as scm
import matplotlib.pyplot as plt

JointsIndex = {'r_ankle': 0, 'r_knee': 1, 'r_hip': 2,
               'l_hip': 3, 'l_knee': 4, 'l_ankle': 5,
               'pelvis': 6, 'thorax': 7, 'neck': 8, 'head': 9,
               'r_wrist': 10, 'r_elbow': 11, 'r_shoulder': 12,
               'l_shoulder': 13, 'l_elbow': 14, 'l_wrist': 15}
JointPairs = [['head', 'neck'], ['neck', 'thorax'],
              ['thorax', 'r_shoulder'], ['thorax', 'l_shoulder'],
              ['r_shoulder', 'r_elbow'], ['r_elbow', 'r_wrist'],
              ['l_shoulder', 'l_elbow'], ['l_elbow', 'l_wrist'],
              ['pelvis', 'r_hip'], ['pelvis', 'l_hip'],
              ['r_hip', 'r_knee'], ['r_knee', 'r_ankle'],
              ['l_hip', 'l_knee'], ['l_knee', 'l_ankle'],
              ['thorax', 'pelvis']]
StickType = ['r-', 'r-', 'g-', 'b-', 'g-', 'g-', 'b-', 'b-', 'c-', 'm-',
             'c-', 'c-', 'm-', 'm-', 'r-']

imgs = open('../test/imglist.txt', 'r').readlines()
images_path = '../data/image/'

f = h5py.File('preds.h5', 'r')
f_keys = f.keys()
#imgs = f['imgs'][:]
preds = f['preds'][:]
f.close()

assert len(imgs) == len(preds)
for i in range(len(imgs)):
    filename = images_path + imgs[i][:-1]
    img = scm.imread(filename)
    pose = preds[i]
    # img = imgs[i].transpose(1,2,0)
    plt.axis('off')
    plt.imshow(img)
    # for j in range(16):
    #     if pose[j][0] > 0 and pose[j][1] > 0:
    #         plt.scatter(pose[j][0], pose[j][1], marker='o', color='r', s=15)
    # plt.show()
    for j in range(len(JointPairs)):
        idx1 = JointsIndex[JointPairs[j][0]]
        idx2 = JointsIndex[JointPairs[j][1]]
        if pose[idx1][0] > 0 and pose[idx1][1] > 0 and \
           pose[idx2][0] > 0 and pose[idx2][1] > 0:
            joints_x = [pose[idx1][0], pose[idx2][0]]
            joints_y = [pose[idx1][1], pose[idx2][1]]
            plt.plot(joints_x, joints_y, StickType[j], linewidth=3)
    plt.show()

print 'Done.'
3. Results
Good results
Poor results (possibly caused by insufficient scales)
