news 2026/4/23 15:50:09

强化学习QAC求最优策略的代码实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
强化学习QAC求最优策略的代码实现

理论基础:

注意:

1. get_policy_by_state_value_net() 是额外写的一个基于Q的贪心策略,不属于QAC算法,得到的策略不一定是最优的,与get_policy_by_policy_net表现一致是偶然现象。

2. 图片中的伪代码并没有说要生成多条episode,但这个为了保证每个(s,a)pair都能被访问到,会生成多条episode。

代码可运行:

import numpy as np import torch from torch import nn from env import GridWorldEnv from utils import drow_policy class QAC(object): def __init__(self, env: GridWorldEnv, gamma=0.9, lr_actor=1e-2, lr_critic=1e-2): self.env = env self.action_space_size = self.env.num_actions self.state_space_size = self.env.num_states self.gamma = gamma self.pnet = nn.Sequential( # policy_net nn.Linear(2, 16), # s -> Π(a|s) nn.ReLU(), nn.Linear(16, self.action_space_size) ) self.qnet = nn.Sequential( # q_value_net nn.Linear(2, 16), # s -> q[s,a] nn.ReLU(), nn.Linear(16, self.action_space_size) ) self.value_optimizer = torch.optim.Adam(self.qnet.parameters(), lr=lr_critic) self.policy_optimizer = torch.optim.Adam(self.pnet.parameters(), lr=lr_actor) self.policy = np.zeros((self.state_space_size, self.action_space_size)) self.q_value = np.zeros((self.state_space_size, self.action_space_size)) def decode_state(self, state): ''' :param state: int :return: 归一化后的元组 ''' i = state // self.env.size j = state % self.env.size return torch.tensor((i / (self.env.size - 1), j / (self.env.size - 1)), dtype=torch.float32) def generate_action(self, state): ''' :param state: tuple :return: int,float ''' logits = self.pnet(state) action_probs = torch.softmax(logits, dim=0) # π(a|s,θ) action_dist = torch.distributions.Categorical(action_probs) # 按分布采样 action = action_dist.sample() log_prob = action_dist.log_prob(action) # In π(a|s,θ) 注意传入的是索引,会自动做log(action_probs[action_index]) return action.item(), log_prob def solve(self, num_episodes=200): for _ in range(num_episodes): state_int = self.env.reset() state = self.decode_state(state_int) done = False while not done: action, log_prob = self.generate_action(state) # a_t,s_t,In π(a_t|s_t,θ) next_state_int, reward, done = self.env.step(state_int, action) # s_t+1,r_t+1 next_state = self.decode_state(next_state_int) if not done: next_action, _ = self.generate_action(next_state) # a_t+1 else: next_action, action_prob = None, None # Critic (value update) qvalue = self.qnet(state)[action] # q(s_t,a_t) if done: td_target = torch.tensor(reward, dtype=torch.float32) else: with torch.no_grad(): # semi gradient qvalue_next = self.qnet(next_state)[next_action] # q(s_t+1,a_t+1) td_target = torch.tensor(reward, dtype=torch.float32) + self.gamma * qvalue_next delta = td_target - qvalue # TD error self.value_optimizer.zero_grad() critic_loss = 0.5 * delta.pow(2) critic_loss.backward() self.value_optimizer.step() # Actor (policy update) qvalue = qvalue.detach() # 避免梯度污染 self.policy_optimizer.zero_grad() actor_loss = -log_prob * qvalue actor_loss.backward() self.policy_optimizer.step() state_int = next_state_int state = next_state def get_policy_by_policy_net(self): for s in range(self.state_space_size): if s in self.env.terminal: self.policy[s,4]=1 break s_t = self.decode_state(s) logits = self.pnet(s_t) action_probs = torch.softmax(logits, dim=0) a=torch.argmax(action_probs) self.policy[s,a]=1 return self.policy def get_policy_by_state_value_net(self): for s in range(self.state_space_size): if s in self.env.terminal: self.policy[s,4]=1 break a = np.argmax(self.q_value[s]) self.policy[s, a] = 1 return self.policy def get_qvalues(self): for s in range(self.state_space_size): s_t = self.decode_state(s) logits = self.qnet(s_t).detach().numpy() # q(s,a)表示在状态s执行动作a后,未来所有折扣回报的期望值,不要取softmax然后取最大 self.q_value[s, :] = logits return self.q_value if __name__ == '__main__': env = GridWorldEnv( size=5, forbidden=[(1, 2), (3, 3)], terminal=[(4, 4)], r_boundary=-1, r_other=-0.04, r_terminal=1, r_forbidden=-1, r_stay=-0.1 ) # 注意samples要大一点,否则每个state被访问到的概率很小 vi = QAC(env=env) vi.solve(num_episodes=200) print("\n state value: ") print(vi.get_qvalues()) print("\n get policy by policy net:") drow_policy(vi.get_policy_by_policy_net(), env) print("\n get policy by state value net:") drow_policy(vi.get_policy_by_state_value_net(), env)

运行结果:

1. 表现一致(终点状态不是 . 是因为没有特殊处理,其他代码保持不变。由于表现一致的情况很少,因此不再继续展示特殊处理后的输出)

2. 表现不一致

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 0:10:15

Java 的节奏哲学:一门不追求“最快”,却极少“失控”的工程语言

在技术讨论中,“快”常常被当作最高追求: 启动要快、响应要快、开发要快、迭代要快。 但在真实工程世界里,很多系统并不是因为“慢”而失败,而是因为节奏失控。节奏失控意味着:负载变化无法预期性能波动难以解释系统状…

作者头像 李华
网站建设 2026/4/17 10:03:48

收藏!大模型项目别瞎做,这样做才拿得到Offer

在CSDN的大模型交流区和我的学习社群里,每天都能刷到类似的困惑:有人晒出自己搭建的第8个RAG系统Demo,简历里却写得像“流水账”;有人把LoRA微调、模型量化玩得炉火纯青,面试时被问“这个技术能帮公司省多少钱”却哑口…

作者头像 李华
网站建设 2026/4/23 11:39:00

Go语言中的切片

Go 语言中的切片(Slice)是一个非常核心的数据结构,它是对数组的抽象和封装,提供了更灵活、强大的序列处理能力。一. 切片的基本概念切片是一个动态数组,它由三个部分组成:指针:指向底层数组的起…

作者头像 李华
网站建设 2026/4/23 14:44:10

粒子数据结构示例

[1]计及网架重构分布式电源容量配置程序 粒子群算法 粒子群算法对配电网分布式电源容量配置 以IEEE33节点为例 以节点电压偏差最小,有功网损最小为优化目标,计及配电网网架重构,优化DG容量和开断支路 包含【参考文献,详细说明】电…

作者头像 李华
网站建设 2026/4/23 14:40:21

APOVMD自适应变分模态分解 通过变分模态分解模态分解的中心频率比值自适应的选择模态数和惩罚因子

APOVMD自适应变分模态分解 通过变分模态分解模态分解的中心频率比值自适应的选择模态数和惩罚因子,避免手动选参造成的过度分解或信息丢失问题 matlab代码,注释清楚;含参考文献 数据为excel数据,使用时替换数据集即可;…

作者头像 李华