國內(nèi)做網(wǎng)站費(fèi)用seo建站網(wǎng)絡(luò)公司
一、概念
????????門控循環(huán)單元(Gated Recurrent Unit,GRU)是一種改進(jìn)的循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),由Cho等人在2014年提出。GRU是LSTM的簡化版本,通過減少門的數(shù)量和簡化結(jié)構(gòu),保留了LSTM的長時(shí)間依賴捕捉能力,同時(shí)提高了計(jì)算效率。GRU通過引入兩個(gè)門(重置門和更新門)來控制信息的流動。與LSTM不同,GRU沒有單獨(dú)的細(xì)胞狀態(tài),而是將隱藏狀態(tài)直接作為信息傳遞的載體,因此結(jié)構(gòu)更簡單,計(jì)算效率更高。
二、核心算法
? ? ? ? 令為時(shí)間步 t 的輸入向量,
為前一個(gè)時(shí)間步的隱藏狀態(tài)向量,
為當(dāng)前時(shí)間步的隱藏狀態(tài)向量,
為當(dāng)前時(shí)間步的重置門向量,
為當(dāng)前時(shí)間步的更新門向量,
為當(dāng)前時(shí)間步的候選隱藏狀態(tài)向量,
分別為各門的權(quán)重矩陣,
為偏置向量,
為sigmoid激活函數(shù),tanh為tanh激活函數(shù),*為元素級乘法。
1、重置門
????????重置門控制前一個(gè)時(shí)間步的隱藏狀態(tài)對當(dāng)前時(shí)間步的影響。通過sigmoid激活函數(shù),重置門的輸出在0到1之間,表示前一個(gè)隱藏狀態(tài)元素被保留的比例。
2、更新門
????????更新門控制前一個(gè)時(shí)間步的隱藏狀態(tài)和當(dāng)前時(shí)間步的候選隱藏狀態(tài)的混合比例。通過sigmoid激活函數(shù),更新門的輸出在0到1之間,表示前一個(gè)隱藏狀態(tài)元素被保留的比例。
3、候選隱藏狀態(tài)
????????候選隱藏狀態(tài)結(jié)合當(dāng)前輸入和前一個(gè)時(shí)間步的隱藏狀態(tài)生成。重置門的輸出與前一個(gè)隱藏狀態(tài)相乘,表示保留的舊信息。然后與當(dāng)前輸入一起通過tanh激活函數(shù)生成候選隱藏狀態(tài)。
4、隱藏狀態(tài)更新
????????隱藏狀態(tài)結(jié)合更新門的結(jié)果進(jìn)行更新。更新門的輸出與前一個(gè)隱藏狀態(tài)相乘,表示保留的舊信息。更新門的補(bǔ)數(shù)與候選隱藏狀態(tài)相乘,表示寫入的新信息。兩者相加得到當(dāng)前時(shí)間步的隱藏狀態(tài)。
三、python實(shí)現(xiàn)
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 設(shè)置隨機(jī)種子
torch.manual_seed(0)
np.random.seed(0)# 生成正弦波數(shù)據(jù)
timesteps = 1000
sin_wave = np.array([np.sin(2 * np.pi * i / timesteps) for i in range(timesteps)])# 創(chuàng)建數(shù)據(jù)集
def create_dataset(data, time_step=1):dataX, dataY = [], []for i in range(len(data) - time_step - 1):a = data[i:(i + time_step)]dataX.append(a)dataY.append(data[i + time_step])return np.array(dataX), np.array(dataY)time_step = 10
X, y = create_dataset(sin_wave, time_step)# 數(shù)據(jù)預(yù)處理
X = X.reshape(X.shape[0], time_step, 1)
y = y.reshape(-1, 1)# 轉(zhuǎn)換為Tensor
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)# 劃分訓(xùn)練集和測試集
train_size = int(len(X) * 0.7)
test_size = len(X) - train_size
trainX, testX = X[:train_size], X[train_size:]
trainY, testY = y[:train_size], y[train_size:]# 定義RNN模型
class GRUModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUModel, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(1, x.size(0), self.hidden_size)out, _ = self.gru(x, h0)out = self.fc(out[:, -1, :])return outinput_size = 1
hidden_size = 50
output_size = 1
model = GRUModel(input_size, hidden_size, output_size)# 定義損失函數(shù)和優(yōu)化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 訓(xùn)練模型
num_epochs = 50
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(trainX)loss = criterion(outputs, trainY)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 預(yù)測
model.eval()
train_predict = model(trainX)
test_predict = model(testX)
train_predict = train_predict.detach().numpy()
test_predict = test_predict.detach().numpy()# 繪制結(jié)果
plt.figure(figsize=(10, 6))
plt.plot(sin_wave, label='Original Data')
plt.plot(np.arange(time_step, time_step + len(train_predict)), train_predict, label='Training Predict')
plt.plot(np.arange(time_step + len(train_predict), time_step + len(train_predict) + len(test_predict)), test_predict, label='Test Predict')
plt.legend()
plt.show()
四、總結(jié)
????????GRU的結(jié)構(gòu)比LSTM更簡單,只有兩個(gè)門(重置門和更新門),沒有單獨(dú)的細(xì)胞狀態(tài)。這使得GRU的計(jì)算復(fù)雜度較低,訓(xùn)練和推理速度更快。通過引入重置門和更新門,GRU也有效地解決了標(biāo)準(zhǔn)RNN在處理長序列時(shí)的梯度消失和梯度爆炸問題。然而,在需要更精細(xì)的門控制和信息流動的任務(wù)中,LSTM的性能可能優(yōu)于GRU。因此在我們實(shí)際的建模過程中,可以根據(jù)數(shù)據(jù)特點(diǎn)選擇合適的RNN系列模型,并沒有哪個(gè)模型能在所有任務(wù)中都具有優(yōu)勢。