在这篇文章中,我们将学习如何将RL环境向量化,并在CPU上并行训练30个Q-learning代理,每秒1.8百万次迭代。
在先前的故事中,我们介绍了时差学习(Temporal-Difference Learning),特别是在GridWorld背景下的Q-learning。
时差学习和探索的重要性:一份插图指南
对动态网格世界上的无模型(Q-learning)和基于模型(Dyna-Q和Dyna-Q+)TD方法的比较。
towardsdatascience.com
虽然这个实现达到了展示这些算法性能和探索机制差异的目的,但它非常缓慢。
事实上,环境和代理主要使用了Numpy进行编码,尽管Numpy在RL中并不是标准,但它使代码易于理解和调试。
在本文中,我们将看到如何通过向量化环境和无缝并行化数十个代理的训练来扩大RL实验。特别是,本文涵盖以下内容:
- JAX的基础知识和对RL有用的功能
- 向量化环境及其如此快的原因
- 在JAX中实现环境、策略和Q-learning代理
- 单个代理的训练
- 如何并行化代理训练,以及它有多容易!
本文中展示的所有代码都可以在GitHub上找到:
GitHub – RPegoud/jax_rl:RL算法和向量化环境的JAX实现
RL算法和向量化环境的JAX实现-GitHub-RPegoud/jax_rl:RL算法和向量化环境的JAX实现
github.com
JAX基础知识
JAX是谷歌开发的另一个Python深度学习框架,被诸如DeepMind等公司广泛使用。
“JAX是在高性能数值计算方面集合在一起的Autograd(自动微分)和XLA(加速线性代数,即TensorFlow编译器)。” — 官方文档
与大多数Python开发者习惯的方式相反,JAX不采用面向对象编程(OOP)的范例,而是采用了函数式编程(FP)[1]。
简单地说,它依赖于纯函数(确定性和没有副作用)以及不可变数据结构(而不是在原地改变数据,使用所需的修改创建新的数据结构)作为主要构建块。因此,FP鼓励更加功能型和数学化的编程方法,使其非常适合数值计算和机器学习等任务。
让我们通过伪代码来说明这两个范例之间的差异,这是一个Q更新函数:
- 面向对象的方法依赖于包含各种状态变量(如Q值)的类实例。更新函数被定义为一个类方法,用于更新实例的内部状态。
- 函数式编程方法依赖于纯函数。事实上,这个Q更新是确定性的,因为Q值作为参数传递。因此,对该函数的任何调用,只要输入相同,输出就会相同,而类方法的输出可能取决于实例的内部状态。另外,数据结构(如数组)在全局范围内被定义和修改。
因此,JAX提供了各种函数修饰符,在RL的环境中特别有用:
- vmap (向量化映射):允许对单个样本进行操作的函数在批处理上应用。例如,如果env.step()是在单个环境中执行一步的函数,vmap(env.step)()就是在多个环境中执行一步的函数。换句话说,vmap为函数添加了一个批处理维度。
- jit (即时编译):允许JAX执行“JAX Python函数的即时编译”,使其与XLA兼容。使用jit可以编译函数并提供显著的速度提升(换取首次编译函数时的额外开销)。
- pmap (并行映射):与vmap类似,pmap实现了简便的并行化。但是,pmap不是在函数中添加一个批处理维度,而是复制函数并在多个XLA设备上执行。注意:应用pmap时,jit也会自动应用。
现在我们已经了解了JAX的基本知识,我们将看到如何通过向量化环境来实现大幅度的加速。
向量化环境:
首先,什么是向量化环境,向量化解决了哪些问题?
在大多数情况下,RL实验由CPU-GPU数据传输减慢。深度学习RL算法(如PPO)使用神经网络来近似策略。
在深度学习中,神经网络在训练和推断时通常使用GPU。然而,在大多数情况下,环境在CPU上运行(即使在同时使用多个环境的情况下也是如此)。
这意味着通常的强化学习循环,即通过策略(神经网络)选择动作并从环境中接收观测和奖励,需要GPU和CPU之间频繁的数据传输,这对性能产生了不利影响。
此外,如果不使用“jitting”等框架(如PyTorch),可能会存在一些开销,因为GPU可能必须等待Python从CPU发送回观测和奖励。
另一方面,JAX使我们能够轻松地在GPU上运行批量环境,消除了GPU和CPU之间数据传输引起的阻碍。
此外,由于jit将我们的JAX代码编译为XLA,执行过程不再(或至少减少)受Python的低效率的影响。
有关更多详细信息和令人兴奋的元学习强化学习研究应用程序,我强烈推荐阅读Chris Lu的博客文章。
环境、代理和策略实现:
让我们来看一下我们的强化学习实验的不同部分的实现。以下是我们需要的基本功能的高级概述:
环境
此实现遵循Nikolaj Goodger在他关于在JAX中编写环境的精彩文章中提供的方案。
在JAX中编写强化学习环境
如何以每秒1.25亿步运行CartPole
VoAGI.com
让我们首先从环境及其方法的高级视图开始。这是在JAX中实现环境的一般计划:
让我们更详细地了解类方法(作为提醒,以“_”开头的函数是私有函数,不应在类的作用域之外调用):
- _get_obs:此方法将环境状态转换为代理的观测。在部分可观测或随机环境中,处理应用于状态的函数将在此处进行。
- _reset:由于我们将并行运行多个代理,因此需要一种方法在每个回合结束时对其进行单独重置。
- _reset_if_done:此方法将在每一步调用,并在“done”标志设置为True时触发_reset。
- reset:此方法在实验开始时被调用,以获取每个代理的初始状态以及相关的随机密钥
- step:给定状态和动作,环境返回一个观测(新状态)、奖励和更新后的“done”标志。
实际上,GridWorld环境的通用实现如下:
请注意,正如之前提到的,所有类方法都遵循函数式编程范式。事实上,我们从不更新类实例的内部状态。此外,类属性都是在实例化后不会被修改的常量。
让我们仔细看一下:
- __init__:在我们的GridWorld环境中,可用的动作为[0, 1, 2, 3]。这些动作使用self.movements被转换成一个二维数组,并添加到step函数的状态中。
- _get_obs:我们的环境是确定性和完全可观察的,因此代理直接接收状态而不是处理后的观察。
- _reset_if_done:env_state参数对应于(state, key)元组,其中key是一个jax.random.PRNGKey。如果done标志设置为True,则此函数简单地返回初始状态。然而,我们不能在JAX jitted函数中使用常规的Python控制流。使用jax.lax.cond,我们本质上获得了一个等同于以下表达式的表达式:
def cond(condition, true_fun, false_fun, operand): if condition: # 如果done标志为True return true_fun(operand) # 返回self._reset(key) else: return false_fun(operand) # 返回env_state
- step:我们将动作转换为移动,并将其添加到当前状态中(jax.numpy.clip确保代理保持在网格内)。然后在检查是否需要重置环境之前更新env_state元组。由于step函数在训练过程中经常使用,通过jit编译它可以显著提高性能。@partial(jit, static_argnums=(0, )装饰器表示类方法的“self”参数应被视为静态的。换句话说,类属性是常量,在连续调用step函数时不会改变。
Q-学习代理
Q-学习代理由update函数定义,以及静态的学习率和折扣因子。
再次注意,在jit编译update函数时,我们将“self”参数作为静态参数传递。此外,请注意,q_values矩阵是使用set()原地修改的,其值不作为类属性存储。
Epsilon-Greedy策略
最后,此实验中使用的策略是标准的Epsilon-Greedy策略。一个重要的细节是它使用随机决策,这意味着如果最大的Q值不唯一,动作将从最大的Q值中均匀采样(使用argmax总是返回第一个具有最大Q值的动作)。如果将Q值初始化为一个由零组成的矩阵,这一点尤为重要,因为动作0(向右移动)将总是被选中。
否则,该策略可以通过以下代码摘要:
action = lax.cond( explore, # 如果p < epsilon _random_action_fn, # 根据密钥选择随机动作 _greedy_action_fn, # 根据Q值选择贪婪动作 operand=subkey, # 将subkey作为上述函数的参数 )return action, subkey
请注意,在JAX中使用密钥(例如,这里我们采样了一个随机浮点数并使用random.choice)后,通常的做法是在此后分割密钥(即“进入新的随机状态”,更多细节可以参考此处)。
单代理训练循环:
现在,我们已经拥有了所有所需的组件,让我们训练一个单代理。
下面是一个Pythonic训练循环,正如您所看到的,我们基本上是使用策略选择一个动作,在环境中执行一步,并更新Q值,直到一个episode结束。然后我们重复该过程进行N个episode。正如我们将在一分钟内看到的,这种训练代理的方式相当低效,但它以可读的方式概括了算法的关键步骤:
在单个CPU上,我们以881个单元和21,680个步骤的速度,在11秒内完成了10,000个情节。
100%|██████████| 10000/10000 [00:11<00:00, 881.86it/s]总步数:238,488每秒步数:21,680
现在,让我们使用JAX语法复制相同的训练循环。以下是rollout函数的高级描述:
总结一下,rollout函数:
- 使用jax.numpy.zeros将观察、奖励和完成标志初始化为空数组,并且数组的维度等于时间步数的数量。初始化的Q值是形状为空矩阵,形状为[timesteps+1, grid_dimension_x, grid_dimension_y, n_actions]。
- 调用env.reset()函数获取初始状态
- 使用jax.lax.fori_loop()函数调用fori_body()函数N次,其中N是时间步参数
- fori_body()函数的行为与先前的Python循环类似。选择动作、执行步骤和计算Q更新后,我们就可以就地更新obs、rewards、done和q_values数组,Q更新目标为时间步t+1)。
这种额外的复杂性导致了85倍的加速,我们现在以大约每秒1.83百万步的速度训练我们的代理程序。请注意,这里的训练是在单个CPU上进行的,因为环境比较简单。
然而,端到端的向量化在应用于复杂环境和从多个GPU中受益的算法时效果更好(Chris Lu的文章中报告了干净RL PPO的PyTorch实现和JAX再现之间的4000倍加速)。
100%|██████████| 1000000/1000000 [00:00<00:00, 1837563.94it/s]总步数:1,000,000每秒步数:1,837,563
训练代理程序后,我们绘制了每个单元格(即状态)的最大Q值的热度图,并观察到它已经有效地学会从初始状态(右下角)移动到目标(左上角)。
GridWorld每个单元格的最大Q值的热度图(作者制作)
并行代理训练循环:
正如承诺的那样,既然我们已经编写了训练单个代理程序所需的函数,那么在批量环境上并行训练多个代理程序将没有太多工作 left!
多亏了vmap,我们可以快速将我们之前的函数转换为批量数据上的工作方式。我们只需要指定期望的输入和输出形状,例如对于env.step:
- in_axes = ((0,0), 0) 表示输入形状,由env_state元组(维度(0, 0))和观察(维度0)组成。
- out_axes = ((0, 0), 0, 0, 0) 表示输出形状,输出为((env_state),obs,reward,done)。
- 现在,我们可以在一组env_states和actions上调用v_step,并获得一组处理过的env_states、观察、奖励和完成标志的数组。
- 请注意,我们还为了性能将所有批处理函数都进行了jit(可以说在训练函数中jit env.reset()是不必要的,因为它只被调用一次)。
我们需要做的最后一项调整是给我们的数组添加一个批次维度来处理每个代理的数据。
通过这样做,我们可以获得一个可以训练多个代理的并行函数,与单个代理函数相比,只需要进行最小的调整:
我们可以在这个版本的训练函数中获得类似的性能:
100%|██████████| 100000/100000 [00:02<00:00, 49036.11it/s]总步数:100,000 * 30 = 3,000,000每秒步数:49,036 * 30 = 1,471,080
就是这样!感谢您阅读到这里,希望本文对在 JAX 中实现向量化环境提供了有帮助的介绍。
如果您喜欢阅读,可以考虑分享本文和给我的 GitHub 仓库点个赞,感谢您的支持! 🙏
GitHub – RPegoud/jax_rl: JAX 实现强化学习算法和向量化环境
JAX 实现强化学习算法和向量化环境 – GitHub – RPegoud/jax_rl: JAX 实现强化学习算法和向量化环境
github.com
最后,对于那些对深入了解感兴趣的人,这里是一些有用的资源清单,这些资源帮助我开始学习 JAX 并撰写本文:
一个精选的 JAX 文章和资源清单:
[1] Coderized,(函数式编程)最纯粹的编码风格,几乎没有错误,YouTube
[2] Aleksa Gordić,JAX 从零到英雄的 YouTube 播放列表(2022),The AI Epiphany
[3] Nikolaj Goodger,在 JAX 中编写 RL 环境(2021)
[4] Chris Lu,通过 PureJaxRL 实现 4000 倍加速和元进化发现(2023),牛津大学,Foerster AI 研究实验室
[5] Nicholas Vadivelu,Awesome-JAX(2020),JAX 的库、项目和资源清单
[6] JAX 官方文档,使用 PyTorch 数据加载训练简单神经网络