torch.gather()取每行中不同列的元素
pytorch取每行中不同列的元素
import torch
scores = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]
])
label=torch.LongTensor([ [0],[1],[2] ])
ans = scores.gather(1, label)
print(ans)

常用场合:信息检索或者推荐系统模型中计算指标要获得item 或者 document经过模型排序后的结果
对item算score,然后要对score算排名,最后根据排名取出前十个item
#item为id,在model里通过item id取出 embedding
score = model(item) # item, score shape [batch, 100]
ranks = torch,argsort(score, dim=-1, descending=True) # 降序排列
ranks = ranks[:, :10] #对每个样本取前十
ids = item.gather(1, ranks) # [batch, 10] ids 为item经过model计算分数后排序的前十物品
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
