主页 > 游戏开发  > 

pytorch中torch.stack()用法虽简单,但不好理解


函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

参数列表

tensors :为一系列输入张量,类型为turple和List dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接 返回值:输出新增维度后的张量

情况一:输入数据为1维数据

dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) #在第0维进行连接,相当于在行上进行组合,取a的一行,b的一行,构成一个新的tensor(输入张量为一维,输出张量为两维) c = torch.stack([a, b],dim=0) print(a) print(b) print(c.size()) print(c) 输出: tensor([1, 2, 3]) tensor([11, 22, 33]) torch.Size([2, 3]) tensor([[ 1, 2, 3], [11, 22, 33]])

dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

import torch a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) print(a) print(b) #在第1维进行连接,相当于在对应行上面对列元素进行组合,取a的一列,b的一列,构成新的tensor的一行(输入张量为一维,输出张量为两维) c = torch.stack([a, b],dim=1) print(c.size()) print(c) 输出: tensor([1, 2, 3]) tensor([11, 22, 33]) torch.Size([3, 2]) tensor([[ 1, 11], [ 2, 22], [ 3, 33]])

情况二:输入数据为2维数据

dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) print(a) print(b) #在第0维进行连接,相当于在通道维度上进行组合 #即取a的所有数据,作为新tensor的一个分量 #取b的所有数据,作为新tensor的另一个分量 #(输入张量为两维,输出张量为三维) c = torch.stack([a, b],dim=0) print(c.size()) print(c) 输出: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) torch.Size([2, 3, 3]) tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[11, 22, 33], [44, 55, 66], [77, 88, 99]]])

dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) print(a) print(b) #在第1维(行)进行连接,相当于对相应通道中每个行进行组合 #取a的一行,b的一行,作为新tensor的第1行和第2行 #原来a:3*3,b:3*3,新tensor:3*2*3 c = torch.stack([a, b], 1) print(c.size()) print(c) 输出: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) torch.Size([3, 2, 3]) tensor([[[ 1, 2, 3], [11, 22, 33]], [[ 4, 5, 6], [44, 55, 66]], [[ 7, 8, 9], [77, 88, 99]]])

dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) print(a) print(b) #在第2维进行连接,相当于对相应行中每个列元素进行组合 #针对每行,取a、b的第一列数据,构成tensor的第一行 #针对每行,取a、b的第二列数据,构成tensor的第二行 #,针对每行取a、b的第三列数据,构成tensor的第三行 #原来a:3*3,b:3*3,新tensor:3*3*2 c = torch.stack([a, b], 2) print(c.size()) print(c) 输出: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) torch.Size([3, 3, 2]) tensor([[[ 1, 11], [ 2, 22], [ 3, 33]], [[ 4, 44], [ 5, 55], [ 6, 66]], [[ 7, 77], [ 8, 88], [ 9, 99]]])

情况三:输入数据为3维数据

dim=0:表示在第0维进行连接,相当于在通道维进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

import torch a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) print(a) print(b) #表示在第0维进行连接,取整个a作为新tensor的一个分量,取整个b作为新tensor的一个分量 c = torch.stack([a, b], 0) print(c) 输出: tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) torch.Size([2, 3, 3]) tensor([[[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) torch.Size([2, 3, 3]) torch.Size([2, 2, 3, 3]) tensor([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[ 10, 20, 30], [ 40, 50, 60], [ 70, 80, 90]]], [[[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]])

dim=1:表示在第1维进行连接,取各自的第1维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。 

import torch a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) print(a) print(a.size()) print(b) print(b.size()) #表示在第1维进行连接,取a的第一维数据[[1, 2, 3], [4, 5, 6], [7, 8, 9]] #取b的第一维数据[[11, 22, 33], [44, 55, 66], [77, 88, 99]]作为新tensor的一个分量 #取a的第一维数据[[10, 20, 30], [40, 50, 60], [70, 80, 90]] #取b的第一维数据[[110, 220, 330], [440, 550, 660], [770, 880, 990]]作为新tensor的另一个分量 c = torch.stack([a, b], 1) print(c.size()) print(c) 输出: tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) torch.Size([2, 3, 3]) tensor([[[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) torch.Size([2, 3, 3]) torch.Size([2, 2, 3, 3]) tensor([[[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]]], [[[ 10, 20, 30], [ 40, 50, 60], [ 70, 80, 90]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]])

dim=2:表示在第2维进行连接,取各自的第2维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

import torch a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) print(a) print(a.size()) print(b) print(b.size()) #表示在第1维进行连接,取a的第2维数据[1, 2, 3] #取b的第2维数据[11, 22, 33]作为新tensor的一个分量 #取a的第2维数据[4, 5, 6] #取b的第2维数据[44, 55, 66]作为新tensor的一个分量 #取a的第2维数据[4, 5, 6] #取b的第2维数据[44, 55, 66]作为新tensor的一个分量 #取a的第2维数据[7, 8, 9] #取b的第2维数据[77, 88, 99]作为新tensor的一个分量 #取a的第2维数据[10, 20, 30] #取b的第2维数据[110, 220, 330]作为新tensor的一个分量 #取a的第2维数据[40, 50, 60] #取b的第2维数据[440, 550, 660]作为新tensor的一个分量 #取a的第2维数据[70, 80, 90] #取b的第2维数据[770, 880, 990]作为新tensor的一个分量 c = torch.stack([a, b], 2) print(c.size()) print(c) 输出: tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) torch.Size([2, 3, 3]) tensor([[[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) torch.Size([2, 3, 3]) torch.Size([2, 3, 2, 3]) tensor([[[[ 1, 2, 3], [ 11, 22, 33]], [[ 4, 5, 6], [ 44, 55, 66]], [[ 7, 8, 9], [ 77, 88, 99]]], [[[ 10, 20, 30], [110, 220, 330]], [[ 40, 50, 60], [440, 550, 660]], [[ 70, 80, 90], [770, 880, 990]]]])

dim=3:表示在第3维进行连接,取各自的第3维度数据,进行拼接。注意:此处输入张量维度为三维,因此dim最大只能为3。

import torch a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) print(a) print(a.size()) print(b) print(b.size()) #针对第二维数据,在每个第二维度相同的情况下,取各自的列数据,构成新tensor的一行 c = torch.stack([a, b], 3) print(c.size()) print(c) 输出: tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) torch.Size([2, 3, 3]) tensor([[[ 11, 22, 33], [ 44, 55, 66], [ 77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) torch.Size([2, 3, 3]) torch.Size([2, 3, 3, 2]) tensor([[[[ 1, 11], [ 2, 22], [ 3, 33]], [[ 4, 44], [ 5, 55], [ 6, 66]], [[ 7, 77], [ 8, 88], [ 9, 99]]], [[[ 10, 110], [ 20, 220], [ 30, 330]], [[ 40, 440], [ 50, 550], [ 60, 660]], [[ 70, 770], [ 80, 880], [ 90, 990]]]])

总结:m个序列数据,在某个维度k进行拼接,该维度大小为n,则拼接后形成了*n*m*大小,具体拼接过程是取m个序列数据,k-1维(设k-1维大小为x,从x=1开始取)相同情况下的第1个数据,构成新tensor的一个行;第二个数据...,第三个数据...构成tensor的新行;然后从x=2开始执行同样的操作

标签:

pytorch中torch.stack()用法虽简单,但不好理解由讯客互联游戏开发栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“pytorch中torch.stack()用法虽简单,但不好理解