一文弄懂RNN、LSTM 和 GRU單元 !
時間:2024-07-10 來源:華清遠見
RNN
循環神經網絡(Recurrent Neural Network,RNN ),主要處理序列數據,輸入的序列數據可以是連續的、長度不固定的序列數據,也可以是固定的序列數據。循環神經網絡能保持對過去事件和當前事件的記憶,從而可以捕獲長距離樣本之間的關聯信息。循環神經網絡在文字預測、語音識別等領域表現較大優勢。
RNN網絡結構解析
圖1是RNN網絡圖示



RNN存在的問題
存在梯度爆炸和消失的問題,對于長距離的句子的學習效果不好。
反向傳播中,對激活函數進行求導,如果此部分大于1,那么層數增多的時候,最終的求出的梯度更新將以指數形式增加,即發生梯度爆炸,如果此部分小于1,那么隨著層數增多,求出的梯度更新信息將會以指數形式衰減,即發生了梯度消失。
RNN代碼示例
pytorch 簡單代碼示例
rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
LSTM
長短期記憶網絡(LSTM,Long Short-Term Memory)是一種時間循環神經網絡,是為了解決一般的RNN(循環神經網絡)存在的長期依賴問題而專門設計出來的。
LSTM網絡結構解析
LSTM網絡結構如圖2所示



LSTM優勢
RNN中只有一個隱藏狀態,LSTM增加了一個元胞狀態單元,其在不同時刻有著可變的連接權重,以解決RNN中梯度消失或爆炸問題。隱藏狀態控制短期記憶,元胞狀態單元控制長期記憶,和配合形成長短期記憶。
LSTM代碼示例
pytorch 簡單代碼示例
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
GRU單元
門控循環單元(gated recurrent unit,GRU)是為了解決循環神經網絡中計算梯度, 以及矩陣連續乘積導致梯度消失或梯度爆炸的問題而提出,GRU更簡單,通常它能夠獲得跟LSTM同等的效果,優勢是計算的速度明顯更快。
GRU單元結構解析
GRU單元結構如圖3所示



GRU優勢
GRU可以取得與LSTM想當甚至更好的性能,且收斂速度更快。
GRU代碼示例
pytorch 簡單代碼示例
rnn = nn.GRU(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)

