torch中transpose和permute转置问题
在pytorch中转置用的函数就只有这两个
transpose():操作不了多维
permute():可以操作多维
百度搜索的教程中的理论太多太多,不如直接代码测试容易让人理解,话不多说,直接代码尝试:
t = torch.randn(2, 4, 5) # 首先创建2个正态分布的4*5的矩阵
print(t)
tensor([[[ 1.0224, 0.5716, -1.2172, -0.0534, -1.0312],[ 0.0622, -0.0260, 2.6485, -0.9420, -0.1987],[-0.6560, 0.0956, -2.2045, -0.6329, 2.3294],[-0.0351, 1.0526, -0.1086, -1.1315, -0.2870]],[[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479],[-0.0086, -1.2145, 2.0289, 0.5889, -0.2644],[ 0.1313, 0.2485, -1.1323, -0.8699, 0.2849],[ 0.3727, -0.0079, 0.3927, 1.4980, 0.5328]]])
# randn(2,4,5)中三个数的索引分别为 0,1,2
t1=t.transpose(1,0) #此时transpose中的参数1,0表示交换t中的索引位置,t1也既是(4,2,5)表示4个2*5的矩阵
# t1=t.transpose(0,1)也是和上述一样的意思,交换0,1的位置,结果都是下图结果
print(t1)
tensor([[[ 1.0224, 0.5716, -1.2172, -0.0534, -1.0312],[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479]],[[ 0.0622, -0.0260, 2.6485, -0.9420, -0.1987],[-0.0086, -1.2145, 2.0289, 0.5889, -0.2644]],[[-0.6560, 0.0956, -2.2045, -0.6329, 2.3294],[ 0.1313, 0.2485, -1.1323, -0.8699, 0.2849]],[[-0.0351, 1.0526, -0.1086, -1.1315, -0.2870],[ 0.3727, -0.0079, 0.3927, 1.4980, 0.5328]]])
同理,permute道理相通,只是可以操作多维而已,但必须传入所有维度数。
t = torch.randn(2, 4, 5) # 首先创建2个正态分布的4*5的矩阵
t
tensor([[[ 0.2464, -1.5848, 0.4432, -0.8214, -1.3044],[-0.0355, -0.4341, 0.3624, -1.4011, 0.0111],[ 1.3601, 0.1008, -1.4646, 0.2118, 0.1643],[ 1.9176, -0.0868, 0.8551, 0.4760, -1.5810]],[[ 0.4147, -1.2642, 1.1018, 0.4975, -0.3797],[-1.0450, 1.0998, -0.8400, 0.5221, 1.0553],[-0.7401, 1.4456, 0.9995, -0.6732, -0.5768],[ 1.0525, 0.5885, 1.3591, -0.3551, -1.4941]]])
t1=t.permute(1,0,2)#必须与t中参数数目一致,执行完这句之后含义表示randn(4,2,5) 4个2*5的矩阵
t1
tensor([[[ 0.2464, -1.5848, 0.4432, -0.8214, -1.3044],[ 0.4147, -1.2642, 1.1018, 0.4975, -0.3797]],[[-0.0355, -0.4341, 0.3624, -1.4011, 0.0111],[-1.0450, 1.0998, -0.8400, 0.5221, 1.0553]],[[ 1.3601, 0.1008, -1.4646, 0.2118, 0.1643],[-0.7401, 1.4456, 0.9995, -0.6732, -0.5768]],[[ 1.9176, -0.0868, 0.8551, 0.4760, -1.5810],[ 1.0525, 0.5885, 1.3591, -0.3551, -1.4941]]])t2=t.permute(2,0,1)#表示randn(5,2,4) 5个2*4的矩阵
t2
tensor([[[ 0.2464, -0.0355, 1.3601, 1.9176],[ 0.4147, -1.0450, -0.7401, 1.0525]],[[-1.5848, -0.4341, 0.1008, -0.0868],[-1.2642, 1.0998, 1.4456, 0.5885]],[[ 0.4432, 0.3624, -1.4646, 0.8551],[ 1.1018, -0.8400, 0.9995, 1.3591]],[[-0.8214, -1.4011, 0.2118, 0.4760],[ 0.4975, 0.5221, -0.6732, -0.3551]],[[-1.3044, 0.0111, 0.1643, -1.5810],[-0.3797, 1.0553, -0.5768, -1.4941]]])
本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!
