wordpress黑桃錘擊河北seo網(wǎng)絡(luò)推廣
強(qiáng)化學(xué)習(xí)筆記之【DDPG算法】
文章目錄
- 強(qiáng)化學(xué)習(xí)筆記之【DDPG算法】
- 前言:
- 原論文偽代碼
- DDPG算法
- DDPG 中的四個(gè)網(wǎng)絡(luò)
- 代碼核心更新公式
前言:
本文為強(qiáng)化學(xué)習(xí)筆記第二篇,第一篇講的是Q-learning和DQN
就是因?yàn)镈DPG引入了Actor-Critic模型,所以比DQN多了兩個(gè)網(wǎng)絡(luò),網(wǎng)絡(luò)名字功能變了一下,其它的就是軟更新之類(lèi)的小改動(dòng)而已
本文初編輯于2024.10.6
CSDN主頁(yè):https://blog.csdn.net/rvdgdsva
博客園主頁(yè):https://www.cnblogs.com/hassle
博客園本文鏈接:

原論文偽代碼

- 上述代碼為DDPG原論文中的偽代碼
DDPG算法
需要先看:
Deep Reinforcement Learning (DRL) 算法在 PyTorch 中的實(shí)現(xiàn)與應(yīng)用【DDPG部分】【沒(méi)有在選擇一個(gè)新的動(dòng)作的時(shí)候,給policy函數(shù)返回的動(dòng)作值增加一個(gè)噪音】【critic網(wǎng)絡(luò)與下面不同】
深度強(qiáng)化學(xué)習(xí)筆記——DDPG原理及實(shí)現(xiàn)(pytorch)【DDPG偽代碼部分】【這個(gè)跟上面的一樣沒(méi)有加噪音】【critic網(wǎng)絡(luò)與上面不同】
【深度強(qiáng)化學(xué)習(xí)】(4) Actor-Critic 模型解析,附Pytorch完整代碼【選看】【Actor-Critic理論部分】
如果需要給policy函數(shù)返回的動(dòng)作值增加一個(gè)噪音,實(shí)現(xiàn)如下

def select_action(self, state, noise_std=0.1):state = torch.FloatTensor(state.reshape(1, -1))action = self.actor(state).cpu().data.numpy().flatten()# 添加噪音,上面兩個(gè)文檔的代碼都沒(méi)有這個(gè)步驟noise = np.random.normal(0, noise_std, size=action.shape)action = action + noisereturn action
DDPG 中的四個(gè)網(wǎng)絡(luò)

注意!!!這個(gè)圖只展示了Critic網(wǎng)絡(luò)的更新,沒(méi)有展示Actor網(wǎng)絡(luò)的更新
- Actor 網(wǎng)絡(luò)(策略網(wǎng)絡(luò)):
- 作用:決定給定狀態(tài) ss 時(shí),應(yīng)該采取的動(dòng)作 a=π(s)a=π(s),目標(biāo)是找到最大化未來(lái)回報(bào)的策略。
- 更新:基于 Critic 網(wǎng)絡(luò)提供的 Q 值更新,以最大化 Critic 估計(jì)的 Q 值。
- Target Actor 網(wǎng)絡(luò)(目標(biāo)策略網(wǎng)絡(luò)):
- 作用:為 Critic 網(wǎng)絡(luò)提供更新目標(biāo),目的是讓目標(biāo) Q 值的更新更為穩(wěn)定。
- 更新:使用軟更新,緩慢向 Actor 網(wǎng)絡(luò)靠近。
- Critic 網(wǎng)絡(luò)(Q 網(wǎng)絡(luò)):
- 作用:估計(jì)當(dāng)前狀態(tài) ss 和動(dòng)作 aa 的 Q 值,即 Q(s,a)Q(s,a),為 Actor 提供優(yōu)化目標(biāo)。
- 更新:通過(guò)最小化與目標(biāo) Q 值的均方誤差進(jìn)行更新。
- Target Critic 網(wǎng)絡(luò)(目標(biāo) Q 網(wǎng)絡(luò)):
- 作用:生成 Q 值更新的目標(biāo),使得 Q 值更新更為穩(wěn)定,減少振蕩。
- 更新:使用軟更新,緩慢向 Critic 網(wǎng)絡(luò)靠近。
大白話(huà)解釋:
? 1、DDPG實(shí)例化為actor,輸入state輸出action
? 2、DDPG實(shí)例化為actor_target
? 3、DDPG實(shí)例化為critic_target,輸入next_state和actor_target(next_state)經(jīng)DQN計(jì)算輸出target_Q
? 4、DDPG實(shí)例化為critic,輸入state和action輸出current_Q,輸入state和actor(state)【這個(gè)參數(shù)需要注意,不是action】經(jīng)負(fù)均值計(jì)算輸出actor_loss
? 5、current_Q 和target_Q進(jìn)行critic的參數(shù)更新
? 6、actor_loss進(jìn)行actor的參數(shù)更新
action實(shí)際上是batch_action,state實(shí)際上是batch_state,而batch_action != actor(batch_state)
因?yàn)閍ctor是頻繁更新的,而采樣是隨機(jī)采樣,不是所有batch_action都能隨著actor的更新而同步更新
Critic網(wǎng)絡(luò)的更新是一發(fā)而動(dòng)全身的,相比于Actor網(wǎng)絡(luò)的更新要復(fù)雜要重要許多
代碼核心更新公式
t a r g e t  ̄ Q = c r i t i c  ̄ t a r g e t ( n e x t  ̄ s t a t e , a c t o r  ̄ t a r g e t ( n e x t  ̄ s t a t e ) ) t a r g e t  ̄ Q = r e w a r d + ( 1 ? d o n e ) × g a m m a × t a r g e t  ̄ Q . d e t a c h ( ) target\underline{~}Q = critic\underline{~}target(next\underline{~}state, actor\underline{~}target(next\underline{~}state)) \\target\underline{~}Q = reward + (1 - done) \times gamma \times target\underline{~}Q.detach() target??Q=critic??target(next??state,actor??target(next??state))target??Q=reward+(1?done)×gamma×target??Q.detach()

- 上述代碼與偽代碼對(duì)應(yīng),意為計(jì)算預(yù)測(cè)Q值
c r i t i c  ̄ l o s s = M S E L o s s ( c r i t i c ( s t a t e , a c t i o n ) , t a r g e t  ̄ Q ) c r i t i c  ̄ o p t i m i z e r . z e r o  ̄ g r a d ( ) c r i t i c  ̄ l o s s . b a c k w a r d ( ) c r i t i c  ̄ o p t i m i z e r . s t e p ( ) critic\underline{~}loss = MSELoss(critic(state, action), target\underline{~}Q) \\critic\underline{~}optimizer.zero\underline{~}grad() \\critic\underline{~}loss.backward() \\critic\underline{~}optimizer.step() critic??loss=MSELoss(critic(state,action),target??Q)critic??optimizer.zero??grad()critic??loss.backward()critic??optimizer.step()

- 上述代碼與偽代碼對(duì)應(yīng),意為使用均方誤差損失函數(shù)更新Critic
a c t o r  ̄ l o s s = ? c r i t i c ( s t a t e , a c t o r ( s t a t e ) ) . m e a n ( ) a c t o r  ̄ o p t i m i z e r . z e r o  ̄ g r a d ( ) a c t o r  ̄ l o s s . b a c k w a r d ( ) a c t o r  ̄ o p t i m i z e r . s t e p ( ) actor\underline{~}loss = -critic(state,actor(state)).mean() \\actor\underline{~}optimizer.zero\underline{~}grad() \\ actor\underline{~}loss.backward() \\ actor\underline{~}optimizer.step() actor??loss=?critic(state,actor(state)).mean()actor??optimizer.zero??grad()actor??loss.backward()actor??optimizer.step()


- 上述代碼與偽代碼對(duì)應(yīng),意為使用確定性策略梯度更新Actor
c r i t i c  ̄ t a r g e t . p a r a m e t e r s ( ) . d a t a = ( t a u × c r i t i c . p a r a m e t e r s ( ) . d a t a + ( 1 ? t a u ) × c r i t i c  ̄ t a r g e t . p a r a m e t e r s ( ) . d a t a ) a c t o r  ̄ t a r g e t . p a r a m e t e r s ( ) . d a t a = ( t a u × a c t o r . p a r a m e t e r s ( ) . d a t a + ( 1 ? t a u ) × a c t o r  ̄ t a r g e t . p a r a m e t e r s ( ) . d a t a ) critic\underline{~}target.parameters().data=(tau \times critic.parameters().data + (1 - tau) \times critic\underline{~}target.parameters().data) \\ actor\underline{~}target.parameters().data=(tau \times actor.parameters().data + (1 - tau) \times actor\underline{~}target.parameters().data) critic??target.parameters().data=(tau×critic.parameters().data+(1?tau)×critic??target.parameters().data)actor??target.parameters().data=(tau×actor.parameters().data+(1?tau)×actor??target.parameters().data)

- 上述代碼與偽代碼對(duì)應(yīng),意為使用策略梯度更新目標(biāo)網(wǎng)絡(luò)
Actor和Critic的角色:
- Actor:負(fù)責(zé)選擇動(dòng)作。它根據(jù)當(dāng)前的狀態(tài)輸出一個(gè)確定性動(dòng)作。
- Critic:評(píng)估Actor的動(dòng)作。它通過(guò)計(jì)算狀態(tài)-動(dòng)作值函數(shù)(Q值)來(lái)評(píng)估給定狀態(tài)和動(dòng)作的價(jià)值。
更新邏輯:
- Critic的更新:
- 使用經(jīng)驗(yàn)回放緩沖區(qū)(Experience Replay)從中采樣一批經(jīng)驗(yàn)(狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)、下一個(gè)狀態(tài))。
- 計(jì)算目標(biāo)Q值:使用目標(biāo)網(wǎng)絡(luò)(critic_target)來(lái)估計(jì)下一個(gè)狀態(tài)的Q值(target_Q),并結(jié)合當(dāng)前的獎(jiǎng)勵(lì)。
- 使用均方誤差損失函數(shù)(MSELoss)來(lái)更新Critic的參數(shù),使得預(yù)測(cè)的Q值(target_Q)與當(dāng)前Q值(current_Q)盡量接近。
- Actor的更新:
- 根據(jù)當(dāng)前的狀態(tài)(state)從Critic得到Q值的梯度(即對(duì)Q值相對(duì)于動(dòng)作的偏導(dǎo)數(shù))。
- 使用確定性策略梯度(DPG)的方法來(lái)更新Actor的參數(shù),目標(biāo)是最大化Critic評(píng)估的Q值。
個(gè)人理解:
DQN算法是將q_network中的參數(shù)每n輪一次復(fù)制到target_network里面
DDPG使用系數(shù) τ \tau τ來(lái)更新參數(shù),將學(xué)習(xí)到的參數(shù)更加soft地拷貝給目標(biāo)網(wǎng)絡(luò)
DDPG采用了actor-critic網(wǎng)絡(luò),所以比DQN多了兩個(gè)網(wǎng)絡(luò)