Published on

PyTorch矩阵乘法

$*$:element-wise乘法,对应元素相乘

支持broadcast操作

a = torch.tensor(torch.ones(10,1))
b = torch.tensor(torch.rand(5)) # tensor([0.4161, 0.4143, 0.7171, 0.4200, 0.6486])
a * b

# 输出:
tensor([[0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486],
        [0.4161, 0.4143, 0.7171, 0.4200, 0.6486]])

$@$:矩阵乘法,不支持broadcast操作

a = torch.tensor(torch.ones(10,1))
b = torch.tensor(torch.rand(5))
a@b
# 输出: RuntimeError: size mismatch, [10 x 1], [5]

a = torch.tensor(torch.ones(10,5))
b = torch.tensor(torch.rand(5))
a@b # shape [10]

a = torch.tensor(torch.ones(10,5))
b = torch.tensor(torch.rand(5,1))
a@b # shape [10,1]

二维矩阵乘法 mm()

不支持broadcast

torch.mm(mat1, mat2, out=None),用于矩阵乘法,同@,

三维带batch的矩阵乘法 bmm()

不支持broadcast

torch.bmm(),在神经网络第一维度通常是batch,batch不参加运算,所以只是第二三维度进行矩阵乘法,

b×n×m``b×m×n--> b×n×n

element-wise mul()

和$*$一样,支持broadcast操作。

多维矩阵乘法 matmul()

支持broadcast

matmul()乘法对参数的后两个维度计算(矩阵乘法),其他维度看作batch,可以进行广播

(1000, 200, 20,10) (200, 10, 20) 最后计算的维度(1000, 200, 20,20)

broadcast

可广播的一对张量满足一下规则:

  • 每个张量至少有一个维度
  • 迭代维度尺寸,从尾部维度开始,存在三种情况,两个张量维度相等;其中一个张量的维度尺寸为1;其中一个张量不存在这个维度

进行广播之后,结果是两个张量对应维度尺寸的较大者