torch中transpose和permute转置问题

在pytorch中转置用的函数就只有这两个
transpose():操作不了多维
permute():可以操作多维

百度搜索的教程中的理论太多太多,不如直接代码测试容易让人理解,话不多说,直接代码尝试:

t = torch.randn(2, 4, 5)   # 首先创建2个正态分布的4*5的矩阵
print(t)
tensor([[[ 1.0224,  0.5716, -1.2172, -0.0534, -1.0312],[ 0.0622, -0.0260,  2.6485, -0.9420, -0.1987],[-0.6560,  0.0956, -2.2045, -0.6329,  2.3294],[-0.0351,  1.0526, -0.1086, -1.1315, -0.2870]],[[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479],[-0.0086, -1.2145,  2.0289,  0.5889, -0.2644],[ 0.1313,  0.2485, -1.1323, -0.8699,  0.2849],[ 0.3727, -0.0079,  0.3927,  1.4980,  0.5328]]])
# randn(2,4,5)中三个数的索引分别为 0,1,2
t1=t.transpose(1,0) #此时transpose中的参数1,0表示交换t中的索引位置,t1也既是(4,2,5)表示4个2*5的矩阵
# t1=t.transpose(0,1)也是和上述一样的意思,交换0,1的位置,结果都是下图结果
print(t1)
tensor([[[ 1.0224,  0.5716, -1.2172, -0.0534, -1.0312],[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479]],[[ 0.0622, -0.0260,  2.6485, -0.9420, -0.1987],[-0.0086, -1.2145,  2.0289,  0.5889, -0.2644]],[[-0.6560,  0.0956, -2.2045, -0.6329,  2.3294],[ 0.1313,  0.2485, -1.1323, -0.8699,  0.2849]],[[-0.0351,  1.0526, -0.1086, -1.1315, -0.2870],[ 0.3727, -0.0079,  0.3927,  1.4980,  0.5328]]])

同理,permute道理相通,只是可以操作多维而已,但必须传入所有维度数。

t = torch.randn(2, 4, 5)   # 首先创建2个正态分布的4*5的矩阵
t
tensor([[[ 0.2464, -1.5848,  0.4432, -0.8214, -1.3044],[-0.0355, -0.4341,  0.3624, -1.4011,  0.0111],[ 1.3601,  0.1008, -1.4646,  0.2118,  0.1643],[ 1.9176, -0.0868,  0.8551,  0.4760, -1.5810]],[[ 0.4147, -1.2642,  1.1018,  0.4975, -0.3797],[-1.0450,  1.0998, -0.8400,  0.5221,  1.0553],[-0.7401,  1.4456,  0.9995, -0.6732, -0.5768],[ 1.0525,  0.5885,  1.3591, -0.3551, -1.4941]]])
t1=t.permute(1,0,2)#必须与t中参数数目一致,执行完这句之后含义表示randn(4,2,5) 4个2*5的矩阵
t1
tensor([[[ 0.2464, -1.5848,  0.4432, -0.8214, -1.3044],[ 0.4147, -1.2642,  1.1018,  0.4975, -0.3797]],[[-0.0355, -0.4341,  0.3624, -1.4011,  0.0111],[-1.0450,  1.0998, -0.8400,  0.5221,  1.0553]],[[ 1.3601,  0.1008, -1.4646,  0.2118,  0.1643],[-0.7401,  1.4456,  0.9995, -0.6732, -0.5768]],[[ 1.9176, -0.0868,  0.8551,  0.4760, -1.5810],[ 1.0525,  0.5885,  1.3591, -0.3551, -1.4941]]])t2=t.permute(2,0,1)#表示randn(5,2,4) 5个2*4的矩阵
t2
tensor([[[ 0.2464, -0.0355,  1.3601,  1.9176],[ 0.4147, -1.0450, -0.7401,  1.0525]],[[-1.5848, -0.4341,  0.1008, -0.0868],[-1.2642,  1.0998,  1.4456,  0.5885]],[[ 0.4432,  0.3624, -1.4646,  0.8551],[ 1.1018, -0.8400,  0.9995,  1.3591]],[[-0.8214, -1.4011,  0.2118,  0.4760],[ 0.4975,  0.5221, -0.6732, -0.3551]],[[-1.3044,  0.0111,  0.1643, -1.5810],[-0.3797,  1.0553, -0.5768, -1.4941]]])


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部