Pytorch基础 – 1. torch.squeeze() 和 unsqueeze()

导读:本篇文章讲解 Pytorch基础 – 1. torch.squeeze() 和 unsqueeze(),希望对大家有帮助,欢迎收藏,转发!站点地址:www.bmabk.com

tensor升维和降维是神经网络的基本操作,比如不同维feature融合等都需要改操作。常用的函数有torch.unsqueeze() 和 torch.unsqueeze()操作。

目录

1. tensor降维操作: torch.squeeze() 和 指定index 

2. tensor升维操作: torch.unsqueeze() 和 使用None

 3. torch.squeeze和torch.unsqueeze的另一种写法


1. tensor降维操作: torch.squeeze() 和 指定index 

(1) 使用torch.squeeze(input,dim),默认删除tensor中所有维度为1的维度,也可指定dim。torch.squeeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])
    a2 = torch.squeeze(a, dim=1)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])
    a3 = torch.squeeze(a, dim=3)
    print(a3.shape)  # torch.Size([2, 1, 3, 4])

(2) 也可使用index=0直接指定,使用torch.equal比较两者相等。

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])

    a2 = a[:, 0, :, 0]
    print(a2.shape)  # torch.Size([2, 3, 4])

    print(torch.equal(a1, a2))  # True

2. tensor升维操作: torch.unsqueeze() 和 使用None

(1) torch.unsqueeze(input, dim) ,对指定的dim,执行升维操作,具体可参考官方文档以及如下示例。torch.unsqueeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=1)
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = torch.unsqueeze(a, dim=2)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

(2) 简单用法:使用None,使用None来增加新维度

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = a[:, None, ...]
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = a[..., None, :]
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

注意:a1中None后面的三个点可以省略,如下

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))

    a1_old = a[:, None, ...]
    print(a1_old .shape)  # torch.Size([2, 1, 3, 4])
    a1_new = a[:, None]
    print(a1_new .shape)  # torch.Size([2, 1, 3, 4])

    print(torch.equal(a1_old, a1_new))  # True

 3. torch.squeeze和torch.unsqueeze的另一种写法

一般情况下使用torch.squeeze(x, dim=?)来进行降维,当然还可以直接使用 x.squeeze(dim=?)。

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=0)
    print(a1.shape)  # torch.Size([1, 2, 3, 4])
    # 另一种写法
    a2 = a.unsqueeze(dim=0)
    print(a2.shape)  # torch.Size([1, 2, 3, 4])

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/87439.html

(0)
小半的头像小半

相关推荐

极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!