Press "Enter" to skip to content

一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题

一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第1张一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第2张

上下文学习是一种最近的范式,其中一个大型语言模型(LLM)观察一个测试实例和一些训练示例作为其输入,并直接解码输出,而不对其参数进行任何更新。这种隐式训练与通常的训练相反,通常的训练会根据示例来改变权重。

一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第3张
来源: https://arxiv.org/pdf/2301.07067.pdf

那么为什么上下文学习会有益呢?你可以假设你有两个回归任务要建模,但唯一的限制是你只能使用一个模型来适应这两个任务。在这种情况下,上下文学习非常有用,因为它可以为每个任务学习回归算法,这意味着模型将为不同的输入集使用单独的适应回归。

“Transformers as Algorithms: Generalization and Implicit Model Selection in In-context Learning”这篇论文中,他们将上下文学习问题形式化为一个算法学习问题。他们使用transformer作为学习算法,在推理时通过训练来实现另一个目标算法。在这篇论文中,他们通过transformer探索了上下文学习的统计学方面,并进行了数值评估以验证理论预测。

在这项工作中,他们研究了两种情况,第一种情况是提示由一系列i.i.d(输入、标签)对组成,而第二种情况是一个动态系统的轨迹(下一个状态取决于前一个状态:xm+1 = f(xm) + noise)。

现在问题来了,我们如何训练这样的模型?

在ICL的训练阶段,T个任务与数据分布 {Dt}t=1T相关联。他们从对应分布中独立采样训练序列St。然后他们从序列St中选择一个子序列S和一个值x,对x进行预测。这就像元学习框架一样。预测之后,我们最小化损失。ICL训练背后的直觉可以解释为在寻找适应当前任务的最优算法。

接下来,为了获得ICL的泛化界限,他们从算法稳定性文献中借用了一些稳定性条件。在ICL中,提示中的训练示例影响到从那一点起算法的未来决策。因此,为了处理这些输入扰动,他们需要对输入施加一些条件。您可以阅读[论文]以获取更多细节。图7显示了对学习算法(这里是Transformer)稳定性进行实验评估的结果。

一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第4张
一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第5张
来源: https://arxiv.org/pdf/2301.07067.pdf

RMTL 是多任务学习中的风险(~错误)。从推导出的界限中得出的一个洞见是,通过增加样本量n或任务中的序列数量M,可以消除ICL的泛化误差。相同的结果也可以扩展到稳定的动态系统。

一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第6张
一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第7张
来源: https://arxiv.org/pdf/2301.07067.pdf
一种新的人工智能(AI)研究方法将基于提示的上下文学习作为一种从统计角度看待的算法学习问题 机器学习 第8张
来源: https://arxiv.org/pdf/2301.07067.pdf

现在让我们通过数值评估来验证这些界限。

所有实验都使用包含12层,8个注意力头和256维嵌入的GPT-2架构进行。实验是在回归和线性动力学上进行的。

  1. 线性回归:在图2(a)和2(b)中,上下文学习结果(红色)优于最小二乘结果(绿色),并与最佳岭/加权解(黑色虚线)完全对齐。这反过来证明了变压器通过学习任务先验能力的自动模型选择能力。
  2. 部分观测到的动态系统:在图2(c)和6中,结果表明上下文学习优于几乎所有订单H=1,2,3,4的最小二乘结果(其中H是滑动到输入状态序列上以生成输入到模型的窗口大小,类似于子序列长度)

总之,他们成功地展示了实验结果与理论预测相一致。对于未来的工作方向,有几个有趣的问题值得探索。

(1) 提出的界限是针对MTL风险的。如何控制各个任务的界限?

(2) 完全观察到的动态系统的相同结果是否可以扩展到更一般的动态系统,如强化学习?

(3) 从观察中得出结论,转移风险仅取决于MTL任务及其复杂性,并且与模型复杂性无关,因此研究这种归纳偏差以及变压器正在学习的算法类型将是有趣的。

Leave a Reply

Your email address will not be published. Required fields are marked *