torch.size unsqueeze()

import torch
obs=[2,4,2,6]obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
print((obs_tensor))
print((obs_tensor).shape)
obs_tensor.unsqueeze(0)  #在0的位置加上一维
print(obs_tensor.unsqueeze(0))
print(obs_tensor.unsqueeze(0).shape)

output

tensor([2., 4., 2., 6.])
torch.Size([4])
tensor([[2., 4., 2., 6.]])
torch.Size([1, 4])

torch.Size括号中有几个数字就是几维,具体参考——torch.size: link

**unsqueeze()**这个函数主要是对数据维度进行扩充。
给指定位置加上维数为一的维度,比如原本有个四行的数据(4),unsqueeze(0)后就会在0的位置加了一维就变成一行四列(1,4)。参考链接: link 和 link


本文来自互联网用户投稿,文章观点仅代表作者本人,不代表本站立场,不承担相关法律责任。如若转载,请注明出处。 如若内容造成侵权/违法违规/事实不符,请点击【内容举报】进行投诉反馈!

相关文章

立即
投稿

微信公众账号

微信扫一扫加关注

返回
顶部