torch.roll函数学习
torch.roll函数是真的比较难以理解,我觉得之后我碰上可能也不一定能转过弯来,因此写一篇博客记录一下。
torch.roll的文档在这里。从官方文档可以看到,torch.roll(input,shifts,dims)
中的三个参数意思都是比较明确的,input即为输入的tensor,shifts表示位移的距离,dims为位移的方向。其中shifts和dims既可以为数字,也可以为元组。
其中最让人困惑的莫过于dims,特别是在高维(大于3维)的时候,基本上感觉怎么移动都不对味。比如我目前需要对一个形状为(batch_size, time_length, feature_size)的向量在时间维度上进行迁移,感觉就怎么都不对。
官方例子的改版:
import torch
x = torch.arange(0,9).view(3,3)
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
torch.roll(x,-1,1)
# tensor([[2, 3, 1],
# [5, 6, 4],
# [8, 9, 7]])
print(x[:,0])
# tensor([1, 4, 7])
可以看到,上面这个例子在dim=1的维度进行操作,而最终这个维度上是没有发生变化的(其他维度上均发生位移)。
在三维上的例子
x = torch.arange(0,27).view(3,3,3)
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8]],
# [[ 9, 10, 11],
# [12, 13, 14],
# [15, 16, 17]],
# [[18, 19, 20],
# [21, 22, 23],
# [24, 25, 26]]])
x = torch.roll(x,-1,1)
# tensor([[[ 3, 4, 5],
# [ 6, 7, 8],
# [ 0, 1, 2]],
# [[12, 13, 14],
# [15, 16, 17],
# [ 9, 10, 11]],
# [[21, 22, 23],
# [24, 25, 26],
# [18, 19, 20]]])
# 效果是等价的
# >>> torch.roll(x[0],-1,0)
# tensor([[3, 4, 5],
# [6, 7, 8],
# [0, 1, 2]])
通过最后的备注语句可以看到,在三维的情况下相当于是对二维的情况进行了广播,这里用PPT简单画了一下。其中的dim=1相当于图中的y-z平面。因此可以显而易见的看到,沿dim=1进行roll,相当于是把y-z平面顺次平移。