torch.einsum函数学习
目录
torch.einsum
简单例子
基础知识见csdn,这篇博客写的比较好。简而言之,einsum就是爱因斯坦求和简记法的实现,这里引用该文的一个例子,非常好懂。
print(a_tensor)
tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
print(b_tensor)
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
# 'ik, kj -> ij'语义解释如下:
# 输入a_tensor: 2维数组,下标为ik,
# 输入b_tensor: 2维数组,下标为kj,
# 输出output:2维数组,下标为ij。
# 隐含语义:输入a,b下标中相同的k,是求和的下标,对应上面的例子2的公式
output = torch.einsum('ik, kj -> ij', a_tensor, b_tensor)
print(output)
tensor([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
高维案例
但是对于高维案例来说,这个简记法就么那么直观了。同样引用原文的例子
a = np.arange(60.).reshape(3,4,5)
b = np.arange(24.).reshape(4,3,2)
# 语义解析:
# 输入a:3阶张量,下标为ijk
# 输入b: 3阶张量,下标为jil
# 输出o: 2阶张量,下标为k和l
# 隐含语义:对i,j进行求和,公式附于代码之后:
o = np.einsum('ijk,jil->kl', a, b)
print(o)
array([[4400., 4730.],
[4532., 4874.],
[4664., 5018.],
[4796., 5162.],
[4928., 5306.]])
# 验证:
print(np.sum(a[:,:,0]*b[:,:,0].T))
4400.0
print(np.sum(a[:,:,1]*b[:,:,0].T))
4532.0
上述式子从k,l -> kl可以猜想到其计算过程o[k,l] = a[i,j,k] * b[j,i,l]
,然后每个元素分别计算出对应的值。
更加复杂
但是对于更复杂的例子,就需要进一步思考。下面对于Informer中的attention计算为例子进行进一步学习。
queries = torch.arange(0,120).view(2,3,4,5)
keys = torch.arange(0,120).view(2,3,4,5)
values = torch.arange(0,120).view(2,3,4,5)
B, L, H, E = queries.shape
_, S, _, D = values.shape
scores = torch.einsum("blhe,bshe->bhls", queries, keys)
result = torch.zeros_like(scores)
for b in range(B):
for l in range(L):
for h in range(H):
for s in range(S):
for e in range(E):
result[b,h,l,s] += queries[b,l,h,e]*keys[b,s,h,e]
# result[b,h,l,s] = torch.sum(queries[b,l,h,:]*keys[b,s,h,:])
从这个例子中可以看出求和符号的计算。