Press "Enter" to skip to content

JAX中的深度强化学习的温和介绍

在不到一秒钟的时间内用DQN解决CartPole环境

Photo by Thomas Despeyroux on Unsplash

近来在强化学习(RL)方面的进展,例如Waymo的自动驾驶出租车或DeepMind的超级人类水平下国际象棋玩家,以深度学习组件,如神经网络梯度优化方法来补充经典RL

借鉴我之前的一篇文章介绍的基础和编码原则,我们将了解并实现深度Q网络DQN)和回放缓冲区来解决OpenAI的CartPole环境。我们将使用JAX在不到一秒钟的时间内完成所有这些!

关于JAX、矢量化环境和Q-learning的介绍,请参考本文的内容:

使用JAX对RL环境进行矢量化和并行化:以光速进行Q-learning⚡

学习如何对GridWorld环境进行矢量化,并在CPU上并行训练30个Q-learning代理,每个代理产生180万步…

towardsdatascience.com

我们选择的深度学习框架将是DeepMind的Haiku库,我最近在变压器的背景下介绍了它:

使用JAX和Haiku从头实现Transformer编码器 🤖

理解Transformers的基本构建模块。

towardsdatascience.com

本文将涵盖以下主题:

  • 为什么我们需要深度强化学习?
  • 深度Q网络的理论和实践
  • 回放缓冲区
  • CartPole环境转换为JAX
  • 使用JAX编写高效训练循环的方法

和往常一样,本文中提供的所有代码均可在GitHub上找到:

GitHub – RPegoud/jym: JAX实现的RL算法和矢量化环境

JAX实现的RL算法和矢量化环境 – GitHub – RPegoud/jym: JAX实现的RL…

github.com

为什么我们需要深度强化学习?

在之前的文章中,我们介绍了时差学习算法,特别是Q-learning

简而言之,Q-learning是一种离策略算法(目标策略不是用于决策的策略),它维护和更新一个Q表,一种将状态映射到相应动作值的明确映射

尽管Q-learning对于具有离散行动空间和受限观测空间的环境是一个实用的解决方案,但它在扩展到更复杂的环境时遇到困难。事实上,创建Q表需要定义动作和观测空间。

自动驾驶为例,观测空间由从摄像头源和其他传感器输入中得出的无限潜在配置组成。另一方面,动作空间包括广泛的方向盘位置和施加于制动器和油门的不同力度。

即使理论上我们可以将动作空间离散化,但是可能状态和动作的数量之多会导致在实际应用中产生一个不切实际的Q表

Photo by Kirill Tonkikh on Unsplash

因此,在大型和复杂的状态和动作空间中寻找最优动作需要强大的函数逼近算法,这正是神经网络所能提供的。在深度强化学习的情况下,神经网络被用作替代Q表,并提供了一种有效解决大状态空间引入的维度诅咒的方法。此外,我们不需要明确定义观测空间。

深度Q网络与经验回放缓冲区

DQN并行使用两种类型的神经网络,首先是“在线”网络,用于Q值预测决策。另一方面,“目标”网络用于创建稳定的Q目标,以通过损失函数评估在线网络的性能。

与Q-learning类似,DQN代理由两个函数定义:actupdate

行动

act函数根据由在线神经网络估计的Q值实施ε-greedy策略。换句话说,代理根据给定状态的最大预测Q值选择相应的动作,并具有以一定概率随机执行动作的设定。

您可能还记得Q-learning在每一步之后更新其Q表,但在深度学习中,通常使用批处理输入上的梯度下降来计算更新。

因此,DQN将经验(包含state,action,reward,next_state,done_flag的元组)存储在回放缓冲区中。为了训练网络,我们将从该缓冲区中采样一批经验,而不仅仅使用最后一次经验(有关详细信息,请参见回放缓冲区部分)。

DQN行动选择过程的可视化表示(作者制作)

这是DQN行动选择部分的JAX实现:

这段代码的唯一微妙之处在于model属性不包含内部参数,这在诸如PyTorch或TensorFlow等框架中通常是例行公事。

在这里,模型是一个代表我们架构的前向传递函数,但可变权重被外部存储并作为参数传递。这解释了为什么我们可以在传递self参数作为静态(模型作为其他类属性没有状态)

更新

update函数负责训练网络。它根据时间差异(TD)误差计算均方误差(MSE)损失:

深度 Q 网络中使用的均方误差

在这个损失函数中,θ表示在线网络的参数,而θ−表示目标网络的参数。目标网络的参数每隔 N 步设置为在线网络的参数,类似于一个检查点(N是一个超参数)。

这种参数分离(使用 θ 表示当前 Q 值和使用 θ− 表示目标 Q 值)对于稳定训练是至关重要的。

如果使用相同的参数,就好像对准移动目标一样,因为网络的更新会立即改变目标值。通过定期更新 θ−(即在一定数量的步骤中冻结这些参数),我们确保 Q 目标稳定,同时在线网络继续学习。

最后,(1-done) 项调整目标以适应终端状态。确实,在一个回合结束时(即 ‘done’ 等于 1),没有下一个状态。因此,下一个状态的 Q 值设置为 0。

DQN 参数更新过程的可视化表示(作者绘制)

为 DQN 实现 update 函数稍微复杂一些,让我们分解一下:

  • 首先,_loss_fn 函数为单个经验实现了之前描述的平方误差。
  • 然后,_batch_loss_fn 作为 _loss_fn 的包装器,并使用 vmap 将损失函数应用于一批经验。然后返回该批次的平均误差。
  • 最后,update 充当损失函数的最终层,根据在线网络参数、目标网络参数和一批经验来计算其梯度。然后使用 Optax(一种常用于优化的 JAX 库)执行优化器步骤并更新在线参数。

请注意,与回放缓冲区类似,模型和优化器都是修改外部状态纯函数。下面这行代码很好地说明了这个原则:

updates, optimizer_state = optimizer.update(grads, optimizer_state)

这也解释了为什么我们可以对在线网络和目标网络使用同一个模型,因为参数被存储并在外部进行更新。

# 目标网络预测self.model.apply(target_net_params, None, state)# 在线网络预测self.model.apply(online_net_params, None, state)

为了上下文,本文中使用的模型是一个多层感知机,定义如下:

N_ACTIONS = 2NEURONS_PER_LAYER = [64, 64, 64, N_ACTIONS]online_key, target_key = vmap(random.PRNGKey)(jnp.arange(2) + RANDOM_SEED)@hk.transformdef model(x):    # 简单的多层感知机    mlp = hk.nets.MLP(output_sizes=NEURONS_PER_LAYER)    return mlp(x)online_net_params = model.init(online_key, jnp.zeros((STATE_SHAPE,)))target_net_params = model.init(target_key, jnp.zeros((STATE_SHAPE,)))prediction = model.apply(online_net_params, None, state)

重放缓冲区

现在让我们退后一步,更详细地看看重放缓冲区。在强化学习中,它们被广泛用于各种原因:

  • 泛化:通过从重放缓冲区中抽样,我们打破了连续经验之间的关联性,通过混合它们的顺序来避免对特定经验序列的过拟合。
  • 多样性:由于采样不仅仅限于最近的经验,我们通常观察到更新的方差较低,并且可以防止对最新经验的过拟合。
  • 提高样本效率:每个经验可以从缓冲区中多次采样,使模型能够从单个经验中学到更多。

最后,我们可以为重放缓冲区使用几种采样方案:

  • 均匀采样:以均匀随机方式对经验进行采样。这种采样类型简单易行,并允许模型独立于它们被收集的时间步骤从经验中学习。
  • 优先采样:此类包括不同的算法,例如优先经验重放(”PER”,Schaul et al. 2015)或梯度经验重放(”GER”,Lahire et al., 2022)。这些方法试图根据与其“学习潜力”相关的某些度量来优先选择经验(对于PER来说是TD误差的幅度,对于GER来说是经验梯度的范数)。

为了简单起见,本文将实现一个均匀重放缓冲区。然而,我计划在未来详细介绍优先采样。

如约定,均匀重放缓冲区的实现相当简单,但与使用JAX和函数式编程相关的一些复杂性。我们必须始终使用纯函数,这些函数没有副作用。换句话说,我们不允许将缓冲区定义为带有内部状态变量的类实例。

相反,我们初始化一个buffer_state字典,将键映射为具有预定义形状的空数组,因为JAX在将代码jit-编译为XLA时需要常量大小的数组。

buffer_state = {    "states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),    "actions": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),    "rewards": jnp.empty((BUFFER_SIZE,), dtype=jnp.int32),    "next_states": jnp.empty((BUFFER_SIZE, STATE_SHAPE), dtype=jnp.float32),    "dones": jnp.empty((BUFFER_SIZE,), dtype=jnp.bool_),}

我们将使用UniformReplayBuffer类与缓冲区状态进行交互。此类有两种方法:

  • add:取消封装一个经验元组,并将其组件映射到特定索引。idx = idx%self.buffer_size确保当缓冲区已满时,添加新的经验会覆盖旧的经验。
  • sample:从均匀随机分布中随机抽取索引序列。序列长度由batch_size设置,索引的范围是[0, current_buffer_size-1]。这确保我们在缓冲区尚未填满时不会抽样空数组。最后,我们使用JAX的vmap结合tree_map返回一批经验。

CartPole环境翻译为JAX

现在我们的DQN agent已经准备好进行训练,我们将使用与之前的一篇文章介绍的相同框架快速实现一个矢量化的CartPole环境。CartPole是一个具有较大的连续观测空间的控制环境,这使得它对我们的DQN进行测试非常合适。

CartPole环境的可视化表示(来源和文档:OpenAI Gymnasium,MIT许可证)

这个过程非常简单,我们在重用OpenAI的Gymnasium实现时确保使用JAX数组和lax控制流代替Python或Numpy的替代方案,例如:

# Python 实现
force = self.force_mag if action == 1 else -self.force_mag
# Jax 实现
force = lax.select(jnp.all(action) == 1, self.force_mag, -self.force_mag)
# Python
costheta, sintheta = math.cos(theta), math.sin(theta)
# Jax
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
# Python
if not terminated:  
    reward = 1.0
else:
    reward = 0.0
# Jax
reward = jnp.float32(jnp.invert(done))

为了简洁起见,完整的环境代码在这里可用:

jym/src/envs/control/cartpole.py at main · RPegoud/jym

JAX实现强化学习算法和向量化环境 – jym/src/envs/control/cartpole.py at main ·…

github.com

用JAX编写高效的训练循环的方法

我们DQN实现的最后一部分是训练循环(也称为模拟运行)。正如前面的文章中提到的,为了充分利用JAX的速度,我们必须遵守特定的格式。

模拟运行函数可能一开始看起来令人生畏,但它的大部分复杂性都是纯粹的语法,因为我们已经涵盖了大部分构建块。以下是伪代码的步骤:

1. 初始化:  * 创建存储状态、动作、奖励和完成标志的空数组以及带有虚拟数组的网络和优化器  * 将所有初始化的对象包装在一个 val 元组中2. 训练循环(重复 i 次):  * 解包 val 元组  * (可选)使用衰减函数衰减 epsilon  * 根据状态和模型参数采取行动  * 执行环境步骤并观察下一个状态、奖励和完成标志  * 创建经验元组(状态、动作、奖励、新状态、完成标志)并将其添加到回放缓冲区  * 根据当前缓冲区大小采样一批经验(即只从具有非零值的经验中采样)  * 使用经验批量更新模型参数  * 每隔 N 步,更新目标网络的权重(设置目标参数 = 在线参数)  * 存储当前集数的经验值并返回更新后的 `val` 元组

现在我们可以运行 DQN 进行20,000 步并观察其表现。在大约 45 个集数之后,代理成功地达到了令人满意的表现,能够持续平衡杆子超过 100 步。

绿色的条形图表示代理成功地平衡了超过 200 步,从而解决了环境问题。值得注意的是,代理在第 51 个集数时创下了自己的393 步的纪录。

DQN 的性能报告(作者制作)

20,000 个训练步骤仅用时一秒多一点,速度为每秒 15,807 步(在一个单 CPU 上)!

这些表现显示出 JAX 的令人印象深刻的扩展能力,使从业人员能够以最小的硬件要求运行大规模并行化实验。

Running for 20,000 iterations: 100%|██████████| 20000/20000 [00:01<00:00, 15807.81it/s]

我们将在未来的文章中更详细地探讨并行化模拟运行过程,以进行具有统计显著性的实验和超参数搜索

同时,欢迎使用这个笔记本来重现实验并尝试调整超参数:

jym/notebooks/control/cartpole/dqn_cartpole.ipynb 在主分支上的 RPegoud/jym

JAX 实现的 RL 算法和矢量化环境 – jym/notebooks/control/cartpole/dqn_cartpole.ipynb at…

github.com

结论

一如既往,感谢您阅读到这里!我希望本文为您提供了关于 JAX 中深度强化学习的良好介绍。如果您对本文内容有任何问题或反馈,请务必让我知道,我很乐意与您稍作交流;)

下次再见 👋

致谢:

Leave a Reply

Your email address will not be published. Required fields are marked *