您好,登錄后才能下訂單哦!
torch.Tensor有4種常見的乘法:*, torch.mul, torch.mm, torch.matmul. 本文拋磚引玉,簡單敘述一下這4種乘法的區別,具體使用還是要參照官方文檔。
點乘
a與b做*乘法,原則是如果a與b的size不同,則以某種方式將a或b進行復制,使得復制后的a和b的size相同,然后再將a和b做element-wise的乘法。
下面以*標量和*一維向量為例展示上述過程。
* 標量
Tensor與標量k做*乘法的結果是Tensor的每個元素乘以k(相當于把k復制成與lhs大小相同,元素全為k的Tensor).
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> a * 2 tensor([[2., 2., 2., 2.], [2., 2., 2., 2.], [2., 2., 2., 2.]])
* 一維向量
Tensor與行向量做*乘法的結果是每列乘以行向量對應列的值(相當于把行向量的行復制,成為與lhs維度相同的Tensor). 注意此時要求Tensor的列數與行向量的列數相等。
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> b = torch.Tensor([1,2,3,4]) >>> b tensor([1., 2., 3., 4.]) >>> a * b tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]])
Tensor與列向量做*乘法的結果是每行乘以列向量對應行的值(相當于把列向量的列復制,成為與lhs維度相同的Tensor). 注意此時要求Tensor的行數與列向量的行數相等。
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> b = torch.Tensor([1,2,3]).reshape((3,1)) >>> b tensor([[1.], [2.], [3.]]) >>> a * b tensor([[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]])
* 矩陣
經Arsmart在評論區提醒,增補一個矩陣 * 矩陣的例子,感謝Arsmart的熱心評論!
如果兩個二維矩陣A與B做點積A * B,則要求A與B的維度完全相同,即A的行數=B的行數,A的列數=B的列數
>>> a = torch.tensor([[1, 2], [2, 3]]) >>> a * a tensor([[1, 4], [4, 9]])
broadcast
點積是broadcast的。broadcast是torch的一個概念,簡單理解就是在一定的規則下允許高維Tensor和低維Tensor之間的運算。broadcast的概念稍顯復雜,在此不做展開,可以參考官方文檔關于broadcast的介紹. 在torch.matmul里會有關于broadcast的應用的一個簡單的例子。
這里舉一個點積broadcast的例子。在例子中,a是二維Tensor,b是三維Tensor,但是a的維度與b的后兩位相同,那么a和b仍然可以做點積,點積結果是一個和b維度一樣的三維Tensor,運算規則是:若c = a * b
, 則c[i,*,*] = a * b[i, *, *]
,即沿著b的第0維做二維Tensor點積,或者可以理解為運算前將a沿著b的第0維也進行了expand操作,即a = a.expand(b.size()); a * b
。
>>> a = torch.tensor([[1, 2], [2, 3]]) >>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]]) >>> a * b tensor([[[ 1, 4], [ 4, 9]], [[-1, -4], [-4, -9]]]) >>> b * a tensor([[[ 1, 4], [ 4, 9]], [[-1, -4], [-4, -9]]])
其實,上面提到的二維Tensor點積標量、二維Tensor點積行向量,都是發生在高維向量和低維向量之間的,也可以看作是broadcast.
torch.mul
官方文檔關于torch.mul的介紹. 用法與*乘法相同,也是element-wise的乘法,也是支持broadcast的。
下面是幾個torch.mul的例子.
乘標量
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> a * 2 tensor([[2., 2., 2., 2.], [2., 2., 2., 2.], [2., 2., 2., 2.]])
乘行向量
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> b = torch.Tensor([1,2,3,4]) >>> b tensor([1., 2., 3., 4.]) >>> torch.mul(a, b) tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [1., 2., 3., 4.]])
乘列向量
>>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> b = torch.Tensor([1,2,3]).reshape((3,1)) >>> b tensor([[1.], [2.], [3.]]) >>> torch.mul(a, b) tensor([[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]])
乘矩陣
例1:二維矩陣 mul 二維矩陣
>>> a = torch.tensor([[1, 2], [2, 3]]) >>> torch.mul(a,a) tensor([[1, 4], [4, 9]])
例2:二維矩陣 mul 三維矩陣(broadcast)
>>> a = torch.tensor([[1, 2], [2, 3]]) >>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]]) >>> torch.mul(a,b) tensor([[[ 1, 4], [ 4, 9]], [[-1, -4], [-4, -9]]])
torch.mm
官方文檔關于torch.mm的介紹. 數學里的矩陣乘法,要求兩個Tensor的維度滿足矩陣乘法的要求.
例子:
>>> a = torch.ones(3,4) >>> b = torch.ones(4,2) >>> torch.mm(a, b) tensor([[4., 4.], [4., 4.], [4., 4.]])
torch.matmul
官方文檔關于torch.matmul的介紹. torch.mm的broadcast版本.
例子:
>>> a = torch.ones(3,4) >>> b = torch.ones(5,4,2) >>> torch.matmul(a, b) tensor([[[4., 4.], [4., 4.], [4., 4.]], [[4., 4.], [4., 4.], [4., 4.]], [[4., 4.], [4., 4.], [4., 4.]], [[4., 4.], [4., 4.], [4., 4.]], [[4., 4.], [4., 4.], [4., 4.]]])
同樣的a和b,使用torch.mm相乘會報錯
>>> torch.mm(a, b) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: matrices expected, got 2D, 3D tensors at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2065
到此這篇關于詳解torch.Tensor的4種乘法的文章就介紹到這了,更多相關torch.Tensor 乘法內容請搜索億速云以前的文章或繼續瀏覽下面的相關文章希望大家以后多多支持億速云!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。