pytorch交换tensor的指定维度

pytorch中有两种方式可以实现tensor指定维度的交换,第一个是torch.permute(),第二个方法是torch.transpose()。

二者不同是torch.permute()可以同时交换多个维度,而torch.transpose()每次只能交换两个维度。

方式一:torch.permute()

参数列表:

  • input:待交换的张量
  • dims:需要交换维度的索引

该函数会按照我们指定维度方式重新排列,例如我们下面定义了一个张量维度为【2,3,4】,如果我们要将维度变为【4,3,2】,就需要交换第一个维度和第三个维度,那我们传入的参数维度索引就应该为【2,1,0】,该索引对应维度的顺序,原来是【0,1,2】,现在是【2,1,0】。

a = torch.randn(2, 3, 4)
print(a.shape)print(torch.permute(a, (2, 1, 0)).shape)
torch.Size([2, 3, 4])
torch.Size([4, 3, 2])

或者还可以使用transpose函数直接交换维度

方式二:torch.transpose()

对于该方法每次只能交换两个维度,输入的参数很简单就是需要交换的两个维度的索引。

print(torch.transpose(a, 0, 1).shape)
torch.Size([3, 2, 4])


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部