北京網(wǎng)站制作西安西安網(wǎng)紅
一、引言
? ? ? ? 引言部分不是論文的重點(diǎn),主要講述了交通預(yù)測的重要性以及一些傳統(tǒng)方法的不足之處。進(jìn)而推出了自己的模型——STGCN。
二、交通預(yù)測與圖卷積
????????
? ? ? ? 第二部分講述了交通預(yù)測中路圖和圖卷積的概念。
? ? ? ? 首先理解道路圖,交通預(yù)測被定義為典型的時(shí)間序列預(yù)測問題,即根據(jù)歷史數(shù)據(jù)預(yù)測未來的交通情況。在該工作中,交通網(wǎng)絡(luò)被建模為圖結(jié)構(gòu),節(jié)點(diǎn)表示監(jiān)控站,邊表示監(jiān)控站之間的連接。
? ? ? ? 其次是圖卷積,由于傳統(tǒng)的卷積操作無法應(yīng)用于圖數(shù)據(jù),作者介紹了兩種擴(kuò)展卷積到圖數(shù)據(jù)的方法,即擴(kuò)展卷積的空間定義、譜圖卷積。這里需要著重注意譜圖卷積,因?yàn)镾TGCN就是采用了傅里葉變換的方法,這里的公式后面也要提到。
三、STGCN結(jié)構(gòu)與原理(重點(diǎn))
????????先看整體網(wǎng)絡(luò)結(jié)構(gòu):
1. 空域卷積塊(圖卷積塊)
1.1 論文
? ? ? ? 這里提到了兩種方法,即切比雪夫多項(xiàng)式和一階近似。對應(yīng)于代碼中的 ChebGraphConv 和 GraphConv。
? ? ? ? 為什么要用切比雪夫多項(xiàng)式計(jì)算?因?yàn)閭鹘y(tǒng)的拉普拉斯計(jì)算起來非常復(fù)雜,使用了切比雪夫多項(xiàng)式以后,通過使用多項(xiàng)式來表示卷積核,圖卷積操作可以在局部化的節(jié)點(diǎn)鄰域內(nèi)進(jìn)行,從而避免了全圖的計(jì)算。
????????一階近似進(jìn)一步簡化了圖卷積的計(jì)算,將圖拉普拉斯算子(公式中的L)的高階近似簡化為一階。這使得圖卷積操作只依賴于當(dāng)前節(jié)點(diǎn)及其直接相鄰節(jié)點(diǎn),計(jì)算復(fù)雜度大幅降低。
1.2 代碼(GraphConvLayer)
? ? ? ? 這里有我加了注釋的代碼:OracleRay/STGCN_pytorch: The PyTorch implementation of STGCN. (github.com)
????????時(shí)空卷積塊和輸出層的定義都在 layers.py 中。
? ? ? ? ChebGraphConv 類的前向傳播方法:
def forward(self, x):# bs, c_in, ts, n_vertex = x.shapex = torch.permute(x, (0, 2, 3, 1)) # 將時(shí)序維度和頂點(diǎn)維度排列到一起,方便后續(xù)的圖卷積計(jì)算if self.Ks - 1 < 0:raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}.')elif self.Ks - 1 == 0:x_0 = xx_list = [x_0] # 只使用第 0 階,即不考慮鄰居節(jié)點(diǎn)的影響,直接使用輸入特征elif self.Ks - 1 == 1:x_0 = x# hi 是鄰接矩陣的索引,btij 是輸入張量 x 的索引,bthj為更新后的特征表示x_1 = torch.einsum('hi,btij->bthj', self.gso, x) # 鄰接矩陣gso和輸入特征x進(jìn)行相乘x_list = [x_0, x_1] # 使用第 0 階和第 1 階的節(jié)點(diǎn)信息elif self.Ks - 1 >= 2:x_0 = xx_1 = torch.einsum('hi,btij->bthj', self.gso, x)x_list = [x_0, x_1]for k in range(2, self.Ks): # 根據(jù)切比雪夫多項(xiàng)式的定義來計(jì)算,利用前兩階多項(xiàng)式來構(gòu)建第 k 階多項(xiàng)式x_list.append(torch.einsum('hi,btij->bthj', 2 * self.gso, x_list[k - 1]) - x_list[k - 2])x = torch.stack(x_list, dim=2) # 將所有階的節(jié)點(diǎn)特征堆疊在一起,形成一個(gè)新的張量cheb_graph_conv = torch.einsum('btkhi,kij->bthj', x, self.weight)if self.bias is not None:cheb_graph_conv = torch.add(cheb_graph_conv, self.bias) # 添加偏置項(xiàng)else:cheb_graph_conv = cheb_graph_conv # 強(qiáng)迫癥return cheb_graph_conv
????????Ks表示空間卷積核大小,gso代表交通預(yù)測圖的鄰接矩陣。在前向傳播的過程中,判斷 Ks - 1 的值的大小。等于0和等于1分別代表著切比雪夫多項(xiàng)式的第 0 階和第 1 階,當(dāng) Ks - 1 大于等于 2 時(shí),不僅需要考慮第 0 階和第 1 階的鄰居節(jié)點(diǎn),還會通過遞歸關(guān)系計(jì)算更高階的鄰居節(jié)點(diǎn)信息。
? ? ? ? 構(gòu)建 k 階多項(xiàng)式時(shí),需要鄰接矩陣 gso 和 (k - 1) 階相乘,然后與?(k - 2) 階相減。這樣就滿足了切比雪夫多項(xiàng)式的要求,即:
? ? ? ? ?最后把得到的值乘以權(quán)重(weight),如果有偏置項(xiàng),再加偏置值(bias)即可。
????????GraphConv 類與此類似。
? ? ? ? 然后在 GraphConvLayer 中就可選擇到底是使用切比雪夫多項(xiàng)式(ChebGraphConv)還是一階近似(GraphConv)。
def forward(self, x):x_gc_in = self.align(x)if self.graph_conv_type == 'cheb_graph_conv':x_gc = self.cheb_graph_conv(x_gc_in)elif self.graph_conv_type == 'graph_conv':x_gc = self.graph_conv(x_gc_in)x_gc = x_gc.permute(0, 3, 1, 2)x_gc_out = torch.add(x_gc, x_gc_in) # 殘差連接return x_gc_out
2. 時(shí)域卷積塊
2.1 論文
???????時(shí)域卷積塊最關(guān)鍵的兩個(gè)內(nèi)容就是因果卷積和門控機(jī)制GLU,不過這兩個(gè)并不是這篇論文里提出來的。
? ? ? ? 因果卷積(代碼中 CausalConv 類)是時(shí)域卷積網(wǎng)絡(luò)模型(TCN)中的重要內(nèi)容。在本篇論文中,使用了1-D因果卷積來確保當(dāng)前時(shí)刻只依賴于過去的輸入數(shù)據(jù),這樣未來的信息在當(dāng)前時(shí)刻就不會被使用。
? ? ? ? GLU是2016年由Yann N. Dauphin在論文《Language Modeling with Gated Convolutional Networks》中提出的。時(shí)域卷積塊采用門控線性單元(GLU)作為非線性激活函數(shù)可以控制哪些輸入是重要的,而不是所有信息都平等對待,這樣有助于在時(shí)間序列中提取關(guān)鍵特征。
? ? ? ? 殘差連接、瓶頸策略、并行訓(xùn)練等作者只是提了一嘴,不是重點(diǎn)。
2.2 代碼(TemporalConvLayer)
? ? ? ? 首先看因果卷積的代碼:
class CausalConv2d(nn.Conv2d):def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1,bias=True):kernel_size = nn.modules.utils._pair(kernel_size) # 卷積核大小,表示對多少個(gè)像素(或特征)進(jìn)行卷積。stride = nn.modules.utils._pair(stride) # 步長,控制卷積核滑動的步幅dilation = nn.modules.utils._pair(dilation) # dilation:膨脹系數(shù),控制采樣間隔,用來擴(kuò)大卷積核的感受野if enable_padding == True: # 啟用零填充self.__padding = [int((kernel_size[i] - 1) * dilation[i]) for i in range(len(kernel_size))]else:self.__padding = 0self.left_padding = nn.modules.utils._pair(self.__padding)super(CausalConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,dilation=dilation, groups=groups, bias=bias)def forward(self, input):if self.__padding != 0:# F.pad() 函數(shù)用于在高度和寬度方向上添加填充input = F.pad(input, (self.left_padding[1], 0, self.left_padding[0], 0))result = super(CausalConv2d, self).forward(input)return result
? ? ? ? 在初始化部分,如果需要零填充,則要先計(jì)算填充量。填充量的計(jì)算公式為:(kernel_size[i] - 1) * dilation[i] 。
? ? ? ? 其中 kernel_size[i] - 1 表示在每個(gè)維度(高度和寬度)上,卷積核“超出”當(dāng)前步長的部分。例如一個(gè) 3 × 3 大小的卷積核,就會有2個(gè)位置影響當(dāng)前位置(左一個(gè)右一個(gè))。而膨脹系數(shù) dilation[i] 則意味著卷積核中元素之間有多少間隔。最后相乘即可求出需要填充幾個(gè)位置。
? ? ? ? 最后在前向傳播過程中用 F.pad() 函數(shù)就可以實(shí)現(xiàn)左填充。
? ? ? ? 然后再看時(shí)域卷積塊(TemporalConvLayer)的前向傳播代碼:
def forward(self, x):x_in = self.align(x)[:, :, self.Kt - 1:, :] # 對其輸入通道數(shù)x_causal_conv = self.causal_conv(x) # 進(jìn)行因果卷積if self.act_func == 'glu' or self.act_func == 'gtu':x_p = x_causal_conv[:, : self.c_out, :, :] # 分割出前半部分x_q = x_causal_conv[:, -self.c_out:, :, :] # 分割出后半部分if self.act_func == 'glu':# 通過門控機(jī)制選擇性保留某些時(shí)間步的特征,這對時(shí)間序列建模非常有效x = torch.mul((x_p + x_in), torch.sigmoid(x_q)) # 對 x_p 和輸入的對齊結(jié)果 x_in 進(jìn)行線性加和,并與 x_q 的 sigmoid 值進(jìn)行點(diǎn)乘else:# tanh(x_p + x_in) ⊙ sigmoid(x_q)x = torch.mul(torch.tanh(x_p + x_in), torch.sigmoid(x_q)) # 使用 tanh 代替線性加和,具有非線性變換的特性elif self.act_func == 'relu':x = self.relu(x_causal_conv + x_in)elif self.act_func == 'silu':x = self.silu(x_causal_conv + x_in)else:raise NotImplementedError(f'ERROR: The activation function {self.act_func} is not implemented.')return x
? ? ? ? 這里的代碼只干了兩件事:因果卷積和選擇激活函數(shù),與論文中的時(shí)域卷積塊的思想大致相同。這里著重理解這行代碼:
x = torch.mul((x_p + x_in), torch.sigmoid(x_q))
? ? ? ? (x_p + x_in)?將前半部分 x_p 與輸入 x_in 的對齊結(jié)果進(jìn)行線性加和,表示對主要特征的組合。sigmoid(x_q) 將后半部分 x_q?通過 sigmoid 函數(shù)轉(zhuǎn)化為 0 到 1 之間的值,作為控制門。最后用⊙符號逐元素相乘,GLU就能決定 x_p 能通過多少信息。
? ? ? ? 默認(rèn)使用GLU激活函數(shù),其他 if-else 語句中的激活函數(shù)不使用。
3. 時(shí)空卷積塊
3.1 論文
? ? ? ? 前面知道,兩個(gè)時(shí)域卷積塊 + 一個(gè)空域卷積塊 = 一個(gè)時(shí)空卷積塊。而且是兩個(gè)時(shí)域卷積塊夾著一個(gè)空域卷積塊的三明治結(jié)構(gòu)。這種設(shè)計(jì)可以同時(shí)處理交通網(wǎng)絡(luò)中的 時(shí)間依賴 和 空間依賴,即模型可以同時(shí)從時(shí)序信息和圖結(jié)構(gòu)中提取重要特征。
????????中間的圖卷積層負(fù)責(zé)從圖結(jié)構(gòu)(如道路網(wǎng)絡(luò))中提取空間特征。通過使用前面提到的圖卷積方法(如切比雪夫多項(xiàng)式近似或一階近似),可以高效地捕捉交通站點(diǎn)之間的連接關(guān)系?。
????????上下兩個(gè)時(shí)間卷積層負(fù)責(zé)提取時(shí)間依賴特征。通過因果卷積的方式,可以確保預(yù)測時(shí)只使用當(dāng)前時(shí)刻及之前的交通數(shù)據(jù),避免未來信息泄露?。
3.2 代碼(STConvBlock)
class STConvBlock(nn.Module):def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, act_func, graph_conv_type, gso, bias, droprate):super(STConvBlock, self).__init__()# “三明治”結(jié)構(gòu):兩個(gè)時(shí)域卷積塊,一個(gè)空域卷積塊self.tmp_conv1 = TemporalConvLayer(Kt, last_block_channel, channels[0], n_vertex, act_func)self.graph_conv = GraphConvLayer(graph_conv_type, channels[0], channels[1], Ks, gso, bias)self.tmp_conv2 = TemporalConvLayer(Kt, channels[1], channels[2], n_vertex, act_func)self.tc2_ln = nn.LayerNorm([n_vertex, channels[2]], eps=1e-12) # 歸一化:緩解梯度消失或梯度爆炸問題self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate) # 正則化:dropout率為0.5def forward(self, x):x = self.tmp_conv1(x)x = self.graph_conv(x)x = self.relu(x)x = self.tmp_conv2(x)x = self.tc2_ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)x = self.dropout(x)return x
? ? ? ? 在初始化函數(shù)中,先定義“三明治”結(jié)構(gòu),使用之前寫好的時(shí)域卷積塊和圖卷積塊。然后確定歸一化方法,激活函數(shù),正則化方法。
4. 輸出層
4.1 論文
? ? ? ? 論文中的輸出層是由一個(gè)時(shí)域卷積塊和一個(gè)全連接層組成的。
? ? ? ? 當(dāng)最后一個(gè)時(shí)空卷積塊處理完數(shù)據(jù)之后,輸出的是一個(gè)三維張量(M × n × C),其中M是時(shí)間步數(shù)(例如,過去60分鐘的交通數(shù)據(jù)),n是交通網(wǎng)絡(luò)中的節(jié)點(diǎn)數(shù)(即監(jiān)測站或路段數(shù)),C是特征通道數(shù)。而論文使用了一個(gè)時(shí)域卷積層將這些特征進(jìn)一步壓縮成一個(gè)單步的時(shí)間預(yù)測輸出。這意味著,時(shí)間卷積會提取多時(shí)間步數(shù)據(jù)中的關(guān)鍵信息,并最終輸出一個(gè)代表未來某一時(shí)刻(如未來15分鐘)的預(yù)測結(jié)果。
????????接下來,在時(shí)間卷積之后,論文使用了一個(gè)全連接層將卷積層輸出的特征張量映射到一個(gè)單一的輸出值,通常是每個(gè)節(jié)點(diǎn)的交通狀態(tài)(如車速或流量),最后生成預(yù)測結(jié)果 v。
4.2 代碼(OutputBlock)
class OutputBlock(nn.Module):def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, act_func, bias, droprate):super(OutputBlock, self).__init__()self.tmp_conv1 = TemporalConvLayer(Ko, last_block_channel, channels[0], n_vertex, act_func)self.fc1 = nn.Linear(in_features=channels[0], out_features=channels[1], bias=bias)self.fc2 = nn.Linear(in_features=channels[1], out_features=end_channel, bias=bias)self.tc1_ln = nn.LayerNorm([n_vertex, channels[0]], eps=1e-12) # 歸一化self.relu = nn.ReLU()self.dropout = nn.Dropout(p=droprate) # 正則化def forward(self, x):x = self.tmp_conv1(x)x = self.tc1_ln(x.permute(0, 2, 3, 1))x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x).permute(0, 3, 1, 2) # 負(fù)責(zé)將時(shí)空特征映射為最終的預(yù)測值return x
?????????代碼部分使用了一個(gè)時(shí)域卷積塊和兩個(gè)全連接層。這是為什么?這樣的設(shè)計(jì)雖然與論文描述的輸出層結(jié)構(gòu)有所不同,但增加了額外的全連接層是為了增強(qiáng)模型的表達(dá)能力和預(yù)測精度。
? ? ? ? 第一個(gè)全連接層用于對時(shí)域卷積輸出的特征進(jìn)行降維或變換。通過這個(gè)全連接層,模型可以將高維度的時(shí)空特征壓縮或轉(zhuǎn)化為新的特征表示,使得模型能夠更好地抽象復(fù)雜的關(guān)系。第二個(gè)全連接層才用于最終的輸出,即生成最后的預(yù)測結(jié)果。
5. 其他代碼
5.1 models.py
? ? ? ? 這個(gè)python文件里主要是對整個(gè)STGCN模型進(jìn)行整合,一共有兩個(gè)類,分別是 STGCNChebGraphConv 和 STGCNGraphConv 。這分別代表著使用切比雪夫多項(xiàng)式還是一階近似。
????????其中的大多數(shù)代碼都是對 layers.py 中的函數(shù)方法調(diào)用,傳參。有一行代碼需要理解:
Ko = args.n_his - (len(blocks) - 3) * 2 * (args.Kt - 1)
????????這句代碼的作用是計(jì)算經(jīng)過多個(gè)時(shí)空卷積塊處理后,保留下來的時(shí)間維度的大小。
- args.n_his:是輸入數(shù)據(jù)的時(shí)間維度大小,通常指輸入的歷史時(shí)間步數(shù)。
- len(blocks) - 3:blocks表示 STGCN 模型中不同層的配置的列表。它的長度再減3是為了去掉輸出層相關(guān)的三層結(jié)構(gòu)(
TNFF
,即兩個(gè)全連接層和最后的時(shí)序處理層),僅關(guān)注時(shí)空卷積部分。- 2 * (args.Kt - 1):Kt 是每個(gè)時(shí)空卷積塊中時(shí)間卷積核的大小。它的大小再減1是因?yàn)槊看尉矸e操作后時(shí)間維度會 - 1(時(shí)間卷積是滑動窗口的形式)。而每個(gè)時(shí)空卷積塊中有兩個(gè)時(shí)間卷積層,所以總的時(shí)間維度減少量會再乘2。
? ? ? ? 因此,對輸入的時(shí)間維度 n_his,每經(jīng)過一個(gè)時(shí)空卷積塊,時(shí)間維度會減少 2 * (args.Kt - 1)。有 len(blocks) - 3 個(gè)這樣的時(shí)空卷積塊,因此總的時(shí)間維度減少的量是 (len(blocks) - 3) * 2 * (args.Kt - 1)。
5.2 main.py
? ? ? ? main 函數(shù)是對整個(gè)代碼運(yùn)行環(huán)境的配置,包括配置環(huán)境變量,設(shè)置命令行參數(shù),數(shù)據(jù)類型轉(zhuǎn)換,確定模型、優(yōu)化器等等。這些代碼在其他網(wǎng)絡(luò)模型同樣受用,大差不差。