- Published on
PyTorch矩阵乘法
$*$:element-wise乘法,对应元素相乘
支持broadcast操作
a = torch.tensor(torch.ones(10,1))
b = torch.tensor(torch.rand(5)) # tensor([0.4161, 0.4143, 0.7171, 0.4200, 0.6486])
a * b
# 输出:
tensor([[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
[0.4161, 0.4143, 0.7171, 0.4200, 0.6486]])
$@$:矩阵乘法,不支持broadcast操作
a = torch.tensor(torch.ones(10,1))
b = torch.tensor(torch.rand(5))
a@b
# 输出: RuntimeError: size mismatch, [10 x 1], [5]
a = torch.tensor(torch.ones(10,5))
b = torch.tensor(torch.rand(5))
a@b # shape [10]
a = torch.tensor(torch.ones(10,5))
b = torch.tensor(torch.rand(5,1))
a@b # shape [10,1]
二维矩阵乘法 mm()
不支持broadcast
torch.mm(mat1, mat2, out=None)
,用于矩阵乘法,同@,
三维带batch的矩阵乘法 bmm()
不支持broadcast
torch.bmm()
,在神经网络第一维度通常是batch,batch不参加运算,所以只是第二三维度进行矩阵乘法,
b×n×m``b×m×n
--> b×n×n
element-wise mul()
和$*$一样,支持broadcast操作。
多维矩阵乘法 matmul()
支持broadcast
matmul()
乘法对参数的后两个维度计算(矩阵乘法),其他维度看作batch,可以进行广播
(1000, 200, 20,10) (200, 10, 20) 最后计算的维度(1000, 200, 20,20)
broadcast
可广播的一对张量满足一下规则:
- 每个张量至少有一个维度
- 迭代维度尺寸,从尾部维度开始,存在三种情况,两个张量维度相等;其中一个张量的维度尺寸为1;其中一个张量不存在这个维度
进行广播之后,结果是两个张量对应维度尺寸的较大者