Press "Enter" to skip to content

PyTorch的Nesterov动量实现有问题吗?

动量可以帮助SGD更高效地遍历复杂的损失函数空间。照片由Maxim Berg在Unsplash上提供。

介绍

如果您仔细阅读PyTorch对SGD的文档,您会发现他们对Nesterov动量的实现与原始论文中的公式有一些差异。最明显的是,PyTorch的实现在当前参数处评估梯度,而Nesterov动量的整个目的是评估平移参数处的梯度。不幸的是,关于这些差异的讨论在互联网上很少见。在本文中,我们将检查和解释PyTorch的实现与Nesterov动量原始公式之间的差异。最终,我们将看到PyTorch的实现并没有错误,而是一种近似,并推测他们实现的好处。

公式

原始论文使用以下更新规则描述Nesterov动量:

其中v_{t+1}和θ_{t+1}分别表示时间t的速度向量和模型参数,μ是动量因子,ε是学习率。PyTorch的SGD文档中的注释指出他们使用以下更新规则:

其中g_{t+1}表示用于计算v_{t+1}的梯度。我们可以展开θ_{t+1}的更新规则:

从中可以推断出:

更新规则变为:

这些是PyTorch在理论上使用的更新规则。我之前提到过,PyTorch实际上在当前参数处评估梯度,而不是平移参数处。通过查看PyTorch SGD文档中的算法描述,我们将在后面进一步调查这一点。

请注意,对于原始(1, 2)和PyTorch(3, 4)公式,如果v_0 = 0,则θ的第一个更新为:

尽管PyTorch的SGD文档注释指出算法将动量缓冲区初始化为第一步的梯度,但我们将在后面显示这意味着v_0 = 0。

初步差异

从原始(1, 2)到PyTorch(3, 4)公式的转换有两个明显的差异:

  1. 学习率移到了v_{t+1}的外部。
  2. 在v_{t+1}的更新规则中,添加了涉及梯度的项,而在θ_{t+1}的更新规则中,减去了涉及速度向量的项。梯度项内的符号差异只是这一差异的结果,如前一节所示。

为了理解这些差异,让我们首先展开更新规则。正如这里所暗示的,如果我们考虑学习率调度,第一个差异的效果在这里更为明显。因此,我们考虑一个更新规则的推广,其中ε不再是固定的,而是可以随时间变化,并将ε_t表示为时间步t的学习率。为了简洁起见,令:

假设v_0 = 0,原始公式变为:

而PyTorch公式变为:

在原始公式(6)中,如果学习率在时间t发生变化,那么只有i = t时求和中的项的大小会受到影响,而所有其他项的大小保持不变。因此,学习率变化的直接影响相当有限,我们必须等待学习率变化在后续时间步骤中“渗透”以对整体步长产生更强的影响。相比之下,在PyTorch公式(7)中,如果学习率在时间t发生变化,那么整个步长的大小会立即受到影响。

对于v_0 = 0,从展开的规则中可以清楚地看出第二个差异最终没有影响;在任何一种公式中,步长都是从当前参数中减去的梯度的折扣总和。

主要差异

忽略权重衰减和阻尼,通过分析PyTorch文档中的SGD算法,我们可以看到实现的更新规则如下:

其中θ’_{t+1} 是时间 t 的模型参数

我们将方程3和4称为PyTorch的“笔记”表示法,方程8和9称为PyTorch的“实际”表示法。我们之所以将θ和θ’区分开来,是有一个即将明显的原因的。与笔记表示法最明显的区别是,梯度是在当前参数而不是偏移参数处进行评估的。仅从这一点来看,算法实现的更新规则似乎不是Nesterov动量的正确实现。

现在我们将研究PyTorch算法如何最终近似Nesterov动量。旧版本PyTorch的推导可以在这里找到,该推导引用了Ivo Danihelka在GitHub问题中提到的内容。当前版本PyTorch的推导可以在这里找到,这是从之前推导中相对简单的调整。我们在这里提供了这些(重新推导的)推导的LaTeX渲染,以方便读者阅读。实际表示法通过简单的变量更改得出。具体来说,我们令:

很明显,v_{t+1}(3)的笔记更新规则经变量更改后等同于v_{t+1}(8)的实际更新规则。我们现在想要推导一个以θ’_t为基础的θ’_{t+1}的更新规则:

这正是我们在PyTorch中看到的更新规则(9)。从高层次来看,PyTorch的实现假设当前参数θ’_t已经是“实际”参数θ_t的偏移版本。因此,在每个时间步骤中,“实际”参数θ_t与当前参数θ’_t的关系为:

然而,从源代码中可以看出,PyTorch的SGD实现在算法结束时没有进行任何校正以获得最终的“实际”参数,因此最终输出在技术上是“实际”参数的近似。

最后,我们现在证明v_0必须为0:

此外,我们可以确认,“实际”参数的第一次更新与原始公式中v_0 = 0时进行的第一次更新是相同的:

我们可以看到,这等同于方程5。

实现公式的好处

当然,最重要的问题是:为什么PyTorch要将Nesterov动量从方程3和4重新表述为方程8和9?一个可能的解释是,重新表述可能在所需的算术操作数量上提供一些节省。为了评估这种可能的解释,让我们统计一下算术操作的数量。对于笔记表示法(3, 4),我们有:

在这里,总共有七个操作。对于实际表示法(8, 9),我们有:

在这里,总共有六个操作。PyTorch实现中的第二个梯度只使用了第一个梯度计算的保存结果,因此每个时间步骤只执行一次梯度计算。因此,一个明显的好处是,PyTorch的实现在每个步骤中减少了一个额外的乘法运算。

结论

总结:

  1. PyTorch的SGD文档注释(3, 4)中的更新规则与原始的Nesterov动量更新规则(1, 2)在学习率的位置上有所不同。这使得学习率调度对整体步长产生立即影响,而原始的公式会使学习率的变化在后续时间步骤中逐渐传播。
  2. PyTorch的SGD算法实现的更新规则(8, 9)是对文档注释(3, 4)中的更新规则的近似,经过简单的变量更改。尽管每个时间步骤中,“实际”参数可以轻松地从当前参数中恢复,但PyTorch的实现在算法结束时没有进行任何校正,因此最终参数在技术上仍然是“实际”最终参数的近似。
  3. PyTorch实现的一个明显好处是,在每个时间步骤中避免了一个额外的乘法运算。

参考文献

  1. “SGD.” SGD — PyTorch 2.0 文档, pytorch.org/docs/stable/generated/torch.optim.SGD.html. 访问日期:2023年9月2日。
  2. Sutskever, Ilya, et al. “深度学习中初始化和动量的重要性。” 机器学习国际会议。PMLR,2013年。
  3. Danihelka, Ivo. “简化 Nesterov 动量。” 2012年8月25日。
  4. Chintala, Soumith. “nesterov 动量在 sgd 中的错误· Issue #27 · torch/optim。” GitHub,2014年10月13日,github.com/torch/optim/issues/27。
  5. Gross, Sam. “在文档中添加有关优化器中动量公式的说明· Issue #1099 · pytorch/pytorch。” GitHub,2017年3月25日,github.com/pytorch/pytorch/issues/1099#issuecomment-289190614。
  6. Zhao, Yilong. “修复 Nesterov 动量错误· Issue #5920 · pytorch/pytorch。” GitHub,2018年3月21日,https://github.com/pytorch/pytorch/pull/5920#issuecomment-375181908。
Leave a Reply

Your email address will not be published. Required fields are marked *