Pytorch基础:torch.expand() 和 torch.repeat()

在torch中,如果要改变某一个tensor的维度,可以利用viewexpandrepeattransposepermute等方法,这里对这些方法的一些容易混淆的地方做个总结。

expand和repeat函数是pytorch中常用于进行张量数据复制维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1

expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

import torch
 
a = torch.tensor([1, 0, 2])     # a -> torch.Size([3])
b1 = a.expand(2, -1)            # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],
        [1, 0, 2]])
'''
 
a = torch.tensor([[1], [0], [2]])   # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2)                 # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为  tensor([[1, 1],
             [0, 0],
             [2, 2]])
'''
 
a = torch.Tensor([[1, 2, 3]])   # a -> torch.Size([1, 3])
b3 = a.expand(4, 3)              # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,
                                # 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor(
	[[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.]]
)
'''
 
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # a -> torch.Size([2, 3])
b4 = a.expand(4, 6)  # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''
 
b5 = a.expand(1, 2, 3)  # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
 
b7 = a.expand(2, 3, 2)  # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''
 
b8 = a.expand(2, -1, -1)  # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
 
# expand返回的张量与原版张量具有相同内存地址
print(b8.storage())  # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

沿着特定维度扩展张量,并返回扩展后的张量

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变
import torch
 
if __name__ == '__main__':
    x = torch.rand(2, 3)
    y1 = x.repeat(4, 2)
    print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

  • torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

  • torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

import torch
 
if __name__ == '__main__':
    x = torch.rand(1, 3)
    y1 = x.expand(4, 3)
    y2 = x.repeat(2, 3)
    print(x.storage().data_ptr(), y1.storage().data_ptr())  # 52364352 52364352
    print(x.storage().data_ptr(), y2.storage().data_ptr())  # 52364352 8852096

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/604565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

gitlab集群高可用架构拆分部署

目录 前言 负载均衡器准备 外部负载均衡器 内部负载均衡器 (可选)Consul服务 Postgresql拆分 1.准备postgresql集群 手动安装postgresql插件 2./etc/gitlab/gitlab.rb配置 3.生效配置文件 Redis拆分 1./etc/gitlab/gitlab.rb配置 2.生效配置文件 Gitaly拆分 1.…

BACnet转MQTT网关智联楼宇json格式自定义

智能建筑的BACnet协议作为楼宇自动化领域的通用语言,正逐步迈向更广阔的物联网世界。随着云计算和大数据技术的飞速发展,如何将BACnet设备无缝融入云端生态系统,成为众多楼宇管理者关注的焦点。本文将以一个实际案例,揭示BACnet网…

60、郑州大学附属肿瘤医院 :用于预测胃癌患者术后生存的深度学习模型的开发和验证[同学,我们的人生应当是旷野]

馒头老师要说的话: 我近期看了一下北京的脑机公司,大概是我之前对这一行业太过于乐观,北京的BCI公司和研究所,比上海、深圳、杭州甚至是重庆都要少,门槛也要高很多。也有我自己的原因,有时站的太高&#x…

92、动态规划-最小路径和

思路: 还是一样,先使用递归来接,无非是向右和向下,然后得到两种方式进行比较,代码如下: public int minPathSum(int[][] grid) {return calculate(grid, 0, 0);}private int calculate(int[][] grid, int …

ubuntu_Docker安装配置

什么是docker? Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中,然后发布到任何流行的 Linux或Windows操作系统的机器上,也可以实现虚拟化。容器是完全使用沙箱机制,相互之间不会有…

为什么要梯度累积

文章目录 梯度累积什么是梯度累积如何理解理解梯度累积梯度累积的工作原理 梯度累积的数学原理梯度累积过程如何实现梯度累积 梯度累积的可视化 梯度累积 什么是梯度累积 随着深度学习模型变得越来越复杂,模型的训练通常需要更多的计算资源,特别是在训…

深度学习笔记_10YOLOv8系列自定义数据集实验

1、mydaya.yaml 配置方法 # 这里分别指向你训练、验证、测试的文件地址,只需要指向图片的文件夹即可。但是要注意图片和labels名称要对应 # 训练集、测试集、验证机文件路径,可以是分类好的TXT文件,也可以直接是图片文件夹路径 train: # t…

Litedram仿真验证(四):AXI接口完成板级DDR3读写测试(FPGA-Artix7)

目录 日常唠嗑一、仿真中遗留的问题二、板级测试三、工程获取及交流 日常唠嗑 接上一篇Litedram仿真验证(三):AXI接口完成仿真(FPGA/Modelsim)之后,本篇对仿真后的工程进行板级验证。 本次板级验证用到的开…

学成在线 - 第3章任务补偿机制实现 + 分块文件清理

7.9 额外实现 7.9.1 任务补偿机制 问题:如果有线程抢占了某个视频的处理任务,如果线程处理过程中挂掉了,该视频的状态将会一直是处理中,其它线程将无法处理,这个问题需要用补偿机制。 单独启动一个任务找到待处理任…

Layer1 公链竞争破局者:Sui 生态的全面创新之路

随着 Sui 生态逐渐在全球范围内树立起声望,并通过与 Revolut 等前沿金融科技平台合作,推广区块链教育与应用,Sui 生态的未来发展方向已成为业界瞩目的焦点。如今,Sui 的总锁定价值已攀升至 5.93 亿美元,充分展示了其在…

分布式架构的演技进过程

最近看了一篇文章,觉得讲的挺不错,就借机给大家分享一下。 早期应用:早期的应用比较简单,访问人数有限,大部分的开发单机就能完成。 分离模型:在业务发展后,用户数量逐步上升,服务器的性能出现瓶颈;就需要将应用和数据分开存储,避免相互抢占资源。 缓存模式:随着系…

历代著名画家作品赏析-东晋顾恺之

中国历史朝代顺序为:夏朝、商朝、西周、东周、秦朝、西楚、西汉、新朝、玄汉、东汉、三国、曹魏、蜀汉、孙吴、西晋、东晋、十六国、南朝、刘宋、南齐、南梁、南陈、北朝、北魏、东魏、北齐、西魏、北周、隋,唐宋元明清,近代。 一、东晋著名…

现身说法暑期三下乡社会实践团一个好的投稿方法胜似千军万马

作为一名在校大学生,去年夏天我有幸参与了学院组织的暑期大学生三下乡社会实践活动,这段经历不仅让我深入基层,体验了不一样的生活,更是在新闻投稿的实践中,经历了一次从传统到智能的跨越。回忆起那段时光,从最初的邮箱投稿困境,到后来智慧软文发布系统的高效运用,每一步都刻印…

顺序栈的操作

归纳编程学习的感悟, 记录奋斗路上的点滴, 希望能帮到一样刻苦的你! 如有不足欢迎指正! 共同学习交流! 🌎欢迎各位→点赞 👍 收藏⭐ 留言​📝既然选择了远方,当不负青春…

什么是DDoS攻击?DDoS攻击的原理是什么?

一、DDoS攻击概念 DDoS攻击又叫“分布式拒绝服务”(Distributed DenialofService)攻击,它是一种通过控制大量计算机、物联网终端或网络僵尸(Zombie)来向目标网站发送大量请求,从而耗尽其服务器资源,导致正常用户无法访…

WEB基础--JDBC操作数据库

使用JDBC操作数据库 使用JDBC查询数据 五部曲:建立驱动,建立连接,获取SQL语句,执行SQL语句,释放资源 建立驱动 //1.加载驱动Class.forName("com.mysql.cj.jdbc.Driver"); 建立连接 //2.连接数据库 Stri…

【3dmax笔记】026:挤出和壳修改器的使用

文章目录 一、修改器二、挤出三、壳 一、修改器 3ds Max中的修改器是一种强大的工具,用于创建和修改复杂的几何形状。这些修改器可以改变对象的形状、大小、方向和位置,以生成所需的效果。以下是一些常见的3ds Max修改器及其功能: 挤出修改…

Google Earth Engine谷歌地球引擎计算遥感影像在每个8天间隔内的多年平均值

本文介绍在谷歌地球引擎(Google Earth Engine,GEE)中,求取多年时间中,遥感影像在每1个8天时间间隔内的多年平均值的方法。 本文是谷歌地球引擎(Google Earth Engine,GEE)系列教学文章…

神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、…

JavaScript数字分隔符

● 如果现在我们用一个很大的数字,例如2300000000,这样真的不便于我们进行阅读,我们希望用千位分隔符来隔开它,例如230,000,000; ● 下面我们使用_当作分隔符来尝试一下 const diameter 287_266_000_000; console.log(diameter)…
最新文章