有什么做節(jié)能報(bào)告的網(wǎng)站福州網(wǎng)站建設(shè)
PyTorch
中的張量具有和NumPy
相同的廣播特性,允許不同形狀的張量之間進(jìn)行計(jì)算。
廣播的實(shí)質(zhì)特性,其實(shí)是低維向量映射到高維之后,相同位置再進(jìn)行相加。我們重點(diǎn)要學(xué)會(huì)的就是低維向量如何向高維向量進(jìn)行映射。
相同形狀的張量計(jì)算
雖然我們覺(jué)得不同形狀之間的張量計(jì)算才是廣播,但其實(shí)相同形狀的張量計(jì)算本質(zhì)上也是廣播。
t1 = torch.arange(3)
t1
# tensor([0, 1, 2])# 對(duì)應(yīng)位置元素相加
t1 + t1
# tensor([0, 2, 4])
與Python對(duì)比
如果兩個(gè)list
相加,結(jié)果是什么?
a = [0, 1, 2]
a + a
# [0, 1, 2, 0, 1, 2]
不同形狀的張量計(jì)算
廣播的特性是不同形狀的張量進(jìn)行計(jì)算時(shí),一個(gè)或多個(gè)張量通過(guò)隱式轉(zhuǎn)化成相同形狀的兩個(gè)張量,從而完成計(jì)算。
但并非任意兩個(gè)不同形狀的張量都能進(jìn)行廣播,因此我們要掌握廣播隱式轉(zhuǎn)化的核心依據(jù)。
2.1 標(biāo)量和任意形狀的張量
標(biāo)量(零維張量)可以和任意形狀的張量進(jìn)行計(jì)算,計(jì)算過(guò)程就是標(biāo)量和張量的每一個(gè)元素進(jìn)行計(jì)算。
# 標(biāo)量與一維向量
t1 = torch.arange(3)
# tensor([0, 1, 2])t1 + 1 # 等效于t1 + torch.tensor(1)
# tensor([1, 2, 3])
# 標(biāo)量與二維向量
t2 = torch.zeros((3, 4))
t2 + 1 # 等效于t2 + torch.tensor(1)
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
2.2 相同維度,不同形狀張量之間的計(jì)算
我們以t2
為例來(lái)探討相同維度、不同形狀的張量之間的廣播規(guī)則。
t2 = torch.zeros(3, 4)
t2
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]])t21 = torch.ones(1, 4)
t21
# tensor([[1., 1., 1., 1.]])
它們都是二維矩陣,t21
的形狀是1×4
,t2
的形狀是3×4
,它們?cè)诘谝粋€(gè)分量上取值不同,但該分量上t21
取值為1,因此可以進(jìn)行廣播計(jì)算:
t2 + t21
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
而t2和t21的實(shí)際計(jì)算過(guò)程如下:
可理解為t21
的一行與t2
的三行分別進(jìn)行了相加。而底層原理為t21
的形狀由1×4
拓展成了t2
的3×4
,然后二者對(duì)應(yīng)位置進(jìn)行了相加。
t22 = torch.ones(3, 1)
t22
# tensor([[1.],
# [1.],
# [1.]])t2 + t22
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
同理,t22+t2
與t21+t2
結(jié)果相同。如果矩陣的兩個(gè)維度都不相同呢?
t23 = torch.arange(3).reshape(3, 1)
t23
# tensor([[0],
# [1],
# [2]])t24 = torch.arange(3).reshape(1, 3)
# tensor([[0, 1, 2]])t23 + t24
# tensor([[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4]])
此時(shí),t23
的形狀是3×1,而t24
的形狀是1×3
,二者的形狀在兩個(gè)份量上均不同,但都有1存在,因此可以廣播:
如果兩個(gè)張量的維度對(duì)應(yīng)數(shù)不同且都不為1,那么就無(wú)法廣播。
t25 = torch.ones(2, 4)
# t2的shape為3×4
t2 + t25
# RuntimeError
高維張量的廣播
高維張量的廣播原理與低維張量的廣播原理一致:
t3 = torch.zeros(2, 3, 4)
t3
# tensor([[[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]],# [[0., 0., 0., 0.],
# [0., 0., 0., 0.],
# [0., 0., 0., 0.]]])t31 = torch.ones(2, 3, 1)
t31
# tensor([[[1.],
# [1.],
# [1.]],# [[1.],
# [1.],
# [1.]]])t3+t31
# tensor([[[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]],# [[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]]])
總結(jié)
維度相同時(shí),如果對(duì)應(yīng)分量不同,但有一個(gè)為1,就可以廣播。
不同維度計(jì)算中的廣播
對(duì)于不同維度的張量,我們首先可以將低維的張量升維,然后依據(jù)相同維度不同形狀的張量廣播規(guī)則進(jìn)行廣播。
低維向量的升維也非常簡(jiǎn)單,只需將更高維度方向的形狀填充為1即可:
# 創(chuàng)建一個(gè)二維向量
t2 = torch.arange(4).reshape(2, 2)
t2
# tensor([[0, 1],
# [2, 3]])# 創(chuàng)建一個(gè)三維向量
t3 = torch.zeros(3, 2, 2)
t3t2 + t3
# tensor([[[0., 1.],
# [2., 3.]],# [[0., 1.],
# [2., 3.]],# [[0., 1.],
# [2., 3.]]])
t3
和t2
的相加,就相當(dāng)于1×2×2
和3×2×2
的兩個(gè)張量進(jìn)行計(jì)算,廣播規(guī)則與低維張量一致。
相信看完本節(jié),你已經(jīng)充分掌握了廣播機(jī)制的運(yùn)算規(guī)則:
- 維度相同時(shí),如果對(duì)應(yīng)分量不同,但有一個(gè)為1,就可以廣播
- 維度不同時(shí),只需將低維向量的更高維度方向的形狀填充為1即可
Pytorch張量操作大全:
Pytorch使用教學(xué)1-Tensor的創(chuàng)建
Pytorch使用教學(xué)2-Tensor的維度
Pytorch使用教學(xué)3-特殊張量的創(chuàng)建與類(lèi)型轉(zhuǎn)化
Pytorch使用教學(xué)4-張量的索引
Pytorch使用教學(xué)5-視圖view與reshape的區(qū)別
Pytorch使用教學(xué)6-張量的分割與合并
Pytorch使用教學(xué)7-張量的廣播
Pytorch使用教學(xué)8-張量的科學(xué)運(yùn)算
Pytorch使用教學(xué)9-張量的線性代數(shù)運(yùn)算
Pytorch使用教學(xué)10-張量操作方法大總結(jié)