Press "Enter" to skip to content

利用预训练的语言模型检查点来构建编码器-解码器模型

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第1张

基于Transformer的编码器-解码器模型最初在Vaswani等人(2017)的论文中提出,并最近引起了广泛的关注,例如Lewis等人(2019),Raffel等人(2019),Zhang等人(2020),Zaheer等人(2020),Yan等人(2020)。

与BERT和GPT2类似,大规模预训练的编码器-解码器模型已经显示出在各种序列到序列任务上显著提升性能(Lewis等人,2019;Raffel等人,2019)。然而,由于预训练编码器-解码器模型所需的巨大计算成本,这类模型的开发主要局限于大型公司和研究机构。

在《利用预训练检查点进行序列生成任务》(2020)一文中,Sascha Rothe、Shashi Narayan和Aliaksei Severyn使用预训练的编码器和/或解码器检查点(如BERT、GPT2)初始化编码器-解码器模型,跳过了昂贵的预训练过程。作者表明,这种热启动的编码器-解码器模型在训练成本的一小部分情况下,能够产生与T5和Pegasus等大规模预训练编码器-解码器模型相竞争的结果,适用于多个序列到序列任务。

在本笔记本中,我们将详细解释如何热启动编码器-解码器模型,并根据Rothe等人(2020)提供实用提示,最后通过一个完整的代码示例展示如何使用🤗Transformers来热启动编码器-解码器模型。

本笔记本分为4个部分:

  • 介绍 – 简要介绍NLP中的预训练语言模型以及热启动编码器-解码器模型的需求。
  • 热启动编码器-解码器模型(理论) – 对编码器-解码器模型如何进行热启动进行说明。
  • 热启动编码器-解码器模型(分析) – 《利用预训练检查点进行序列生成任务》的总结
      • 哪些模型组合对于热启动编码器-解码器模型有效?它在不同任务中有何不同?
  • 使用🤗Transformers热启动编码器-解码器模型(实践) – 完整的代码示例,详细展示如何使用EncoderDecoderModel框架来热启动基于Transformer的编码器-解码器模型。

强烈推荐(可能甚至是必须的)阅读有关基于Transformer的编码器-解码器模型的博客文章。

让我们从对热启动编码器-解码器模型的背景介绍开始。

介绍

最近,预训练语言模型1 {}^1 1在自然语言处理(NLP)领域引起了革命。

最初的预训练语言模型基于循环神经网络(RNN),由Dai等人(2015)提出。Dai等人展示了在未标记数据上预训练基于RNN的模型,然后在特定任务上进行微调2 {}^2 2,其结果比直接在该任务上随机初始化模型进行训练要好。然而,直到2018年,预训练语言模型才在NLP领域得到广泛接受。Peters等人的ELMO和Howard等人的ULMFIT是第一批显著改进自然语言理解(NLU)任务的最新技术的预训练语言模型。几个月后,OpenAI和Google发布了基于Transformer的预训练语言模型,分别称为Radford等人的GPT和Devlin等人的BERT。相比于RNN,Transformer的语言模型改进了效率,使GPT2和BERT能够在大量未标记的文本数据上进行预训练。一旦预训练完成,BERT和GPT只需进行很少的微调,就能在超过十几个NLU任务上取得破纪录的结果3 {}^3 3。

预训练语言模型将任务无关的知识有效地转化为任务特定的知识,成为NLU的重要推动因素。以前,工程师和研究人员必须从头开始训练语言模型,而现在大规模预训练语言模型的公开检查点可以以较低的成本和时间进行微调。这可以为行业节省数百万美元,并为研究提供更快的原型设计和更好的基准。

预训练语言模型在NLU任务上取得了新的性能水平,越来越多的研究都建立在利用这种预训练语言模型来改进NLU系统的基础上。然而,独立的BERT和GPT模型在序列到序列的任务(例如文本摘要、机器翻译、句子改写等)上的表现较差。

序列到序列的任务被定义为从输入序列X1:n到输出序列Y1:m的映射,其中输出序列的长度m是先验未知的。因此,序列到序列模型应该定义在给定输入序列X1:n的条件下输出序列Y1:m的条件概率分布:

pθmodel(Y1:m|X1:n)。

不失一般性地,输入的n个词的序列可以用向量序列X1:n = x1, …, xn来表示,输出的m个词的序列可以用Y1:m = y1, …, ym来表示。

让我们看看BERT和GPT2如何适应序列到序列的任务。

BERT

BERT是一个仅包含编码器的模型,它将输入序列X1:n映射到上下文编码的序列X̄1:n

fθBERT: X1:n → X̄1:n

然后,BERT的上下文编码的序列X̄1:n可以进一步通过一个分类层进行处理,用于NLU分类任务,例如情感分析、自然语言推理等。为此,分类层(通常是一个汇聚层后面跟着一个前馈层)作为BERT的最后一层添加到上下文编码的序列X̄1:n上,将其映射到一个类别c:

fθp,c: X̄1:n → c。

已经证明,在预训练的BERT模型θBERT之上添加一个汇聚和分类层(定义为θp,c),并对整个模型{θp,c,θBERT}进行微调,可以在各种NLU任务上取得最先进的性能,参见Devlin等人的BERT。

让我们可视化BERT。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第2张

BERT模型显示为灰色。模型堆叠多个BERT块,每个块由双向自注意力层(显示在红框的下部)和两个前馈层(显示在红框的上部)组成。

每个BERT块利用双向自注意力对输入序列x’1,…,x’n(显示为浅灰色)进行处理,得到更“精细”的上下文输出序列x”1,…,x”n(显示为稍深灰色)4 {}^4 4。最后一个BERT块的上下文输出序列,即X̄1:n,可以通过添加一个任务特定的分类层(显示为橙色)来映射到单个输出类别c,如上所述。

仅编码器模型只能将输入序列映射到已知输出长度。总之,输出维度不依赖于输入序列,这使得仅编码器模型在序列到序列任务中具有不利和不实用的特点。

对于所有仅编码器模型而言,BERT的架构与基于Transformer的编码器-解码器模型的编码器部分的架构完全相同,如“编码器”部分所示。

GPT2

GPT2是一个仅解码器模型,它利用单向(即“因果关系”)自注意力来定义从输入序列Y0:m-1(显示为1)到“下一个单词”逻辑向量序列L1:m的映射:

fθGPT2:Y0:m-1→L1:m。

通过对逻辑向量L1:m进行softmax操作,模型可以定义单词序列Y1:m的概率分布。确切地说,单词序列Y1:m的概率分布可以分解为m-1个条件的“下一个单词”分布:

pθGPT2(Y1:m) = ∏i=1m pθGPT2(yi∣Y0:i-1)。

pθGPT2(yi∣Y0:i-1)在此处表示给定所有先前单词y0,…,yi-1的情况下下一个单词yi的概率分布3 {}^3 3,并且被定义为应用于逻辑向量li的softmax操作。总结起来,以下等式成立。

p θ gpt2 ( y i ∣ Y 0 : i − 1 ) = Softmax ( l i ) = Softmax ( f θ GPT2 ( Y 0 : i − 1 ) ) . p_{\theta_{\text{gpt2}}}(\mathbf{y}_i | \mathbf{Y}_{0:i-1}) = \textbf{Softmax}(\mathbf{l}_i) = \textbf{Softmax}(f_{\theta_{\text{GPT2}}}(\mathbf{Y}_{0: i – 1})). p θ gpt2 ​ ​ ( y i ​ ∣ Y 0 : i − 1 ​ ) = Softmax ( l i ​ ) = Softmax ( f θ GPT2 ​ ​ ( Y 0 : i − 1 ​ ) ) .

更多细节请参考编码器-解码器博文中的解码器部分。

现在我们来可视化一下 GPT2。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第3张

与 BERT 类似,GPT2 由一系列的 GPT2 块组成。与 BERT 块不同,GPT2 块使用单向自注意力机制处理一些输入向量 y ′ 0 , … , y ′ m − 1 \mathbf{y’}_0, \ldots, \mathbf{y’}_{m-1} y ′ 0 ​ , … , y ′ m − 1 ​(在右下角的浅蓝色部分显示)到一个输出向量序列 y ′ ′ 0 , … , y ′ ′ m − 1 \mathbf{y”}_0, \ldots, \mathbf{y”}_{m-1} y ′ ′ 0 ​ , … , y ′ ′ m − 1 ​(在右上角的深蓝色部分显示)。除了 GPT2 块堆栈之外,该模型还有一个称为 LM Head 的线性层,该层将最后一个 GPT2 块的输出向量映射到逻辑向量 l 1 , … , l m \mathbf{l}_1, \ldots, \mathbf{l}_m l 1 ​ , … , l m ​。如前所述,逻辑向量 l i \mathbf{l}_i l i ​ 可以用于采样新的输入向量 y i \mathbf{y}_i y i ​ ^5 5 。

GPT2 主要用于开放域的文本生成。首先,将输入提示 Y 0 : i − 1 \mathbf{Y}_{0:i-1} Y 0 : i − 1 ​ 提供给模型,以产生条件分布 p θ gpt2 ( y ∣ Y 0 : i − 1 ) p_{\theta_{\text{gpt2}}}(\mathbf{y} | \mathbf{Y}_{0:i-1}) p θ gpt2 ​ ​ ( y ∣ Y 0 : i − 1 ​ ) 。然后从分布中采样下一个单词 y i \mathbf{y}_i y i ​(在图中由灰色箭头表示),并将其附加到输入中。以自回归的方式,可以从 p θ gpt2 ( y ∣ Y 0 : i ) p_{\theta_{\text{gpt2}}}(\mathbf{y} | \mathbf{Y}_{0:i}) p θ gpt2 ​ ​ ( y ∣ Y 0 : i ​ ) 中采样下一个单词 y i + 1 \mathbf{y}_{i+1} y i + 1 ​,以此类推。

因此,GPT2 非常适用于语言生成,但对于条件生成则不太适用。通过将输入提示 Y 0 : i − 1 \mathbf{Y}_{0: i-1} Y 0 : i − 1 ​ 设置为序列输入 X 1 : n \mathbf{X}_{1:n} X 1 : n ​ ,可以很好地用于条件生成。然而,与编码器-解码器架构相比,该模型架构存在根本缺陷,如 Raffel 等人在2019年的第17页所解释的。简而言之,单向自注意力强制模型对序列输入 X 1 : n \mathbf{X}_{1:n} X 1 : n ​ 的表示受到不必要的限制,因为 x i \mathbf{x}_i x i ​ 不能依赖于 x i + 1 ,∀ i ∈ { 1 , … , n } \mathbf{x}_{i+1}, \forall i \in \{1,\ldots, n\} x i + 1 ​ , ∀ i ∈ { 1 , … , n } 。

编码器-解码器

由于仅编码器模型要求事先知道输出长度,它们似乎不适用于序列到序列的任务。仅解码器模型可以很好地完成序列到序列的任务,但也存在一定的架构限制,如上所述。

目前应对序列到序列任务的主要方法是基于Transformer的编码器-解码器模型,通常也称为seq2seq Transformer模型。编码器-解码器模型最早在Vaswani等人(2017)的论文中提出,自那以后在序列到序列任务上的表现优于独立语言模型(即仅解码器模型),例如Raffel等人(2020)的论文。本质上,编码器-解码器模型是一个独立的编码器(如BERT)和一个独立的解码器模型(如GPT2)的组合。关于基于Transformer的编码器-解码器模型的详细架构,请参阅此博文。

现在,我们知道大型预训练独立编码器和解码器模型(如BERT和GPT)的免费检查点可以提高性能并降低训练成本,适用于许多自然语言理解任务。我们还知道编码器-解码器模型本质上是独立编码器和解码器模型的组合。这自然引出了一个问题,即如何利用独立模型检查点来构建编码器-解码器模型,以及哪些模型组合在特定的序列到序列任务上性能最佳。

在2020年,Sascha Rothe、Shashi Narayan和Aliaksei Severyn在他们的论文《利用预训练检查点进行序列生成任务》中对这个问题进行了详细研究。该论文对不同的编码器-解码器模型组合和微调技术进行了深入分析,我们将在后面更详细地研究。

将预训练独立模型检查点组合成编码器-解码器模型被定义为对编码器-解码器模型进行热启动。下面的章节将展示热启动编码器-解码器模型在理论上的工作原理,以及如何在实践中使用🤗Transformers来实现,还提供了提高性能的实用技巧。


1 {}^1 1 预训练语言模型被定义为一个神经网络:

  • 它是在无标签文本数据上进行训练的,即以任务无关、无监督的方式进行训练,
  • 它将一系列输入词语处理成上下文相关的嵌入。例如,Mikolov等人(2013)的连续词袋模型和跳字模型不被视为预训练语言模型,因为这些嵌入是上下文无关的。

2 {}^2 2 微调被定义为对已初始化为预训练语言模型权重的模型进行任务特定的训练。

3 {}^3 3 输入向量 y 0 \mathbf{y}_0 y 0 ​ 对应于用于预测第一个输出词 y 1 \mathbf{y}_1 y 1 ​ 所需的BOS \text{BOS} BOS 嵌入向量。

4 {}^4 4 为了不使方程和插图过于混乱,我们在此忽略了归一化层。

5 {}^5 5 关于为什么在”仅解码器”模型(如GPT2)中使用单向自注意力以及抽样的具体工作原理,请参阅编码器-解码器博文中的解码器部分。

编码器-解码器模型的热启动(理论)

通过阅读介绍,我们现在熟悉了仅编码器模型和仅解码器模型。我们已经注意到编码器-解码器模型架构本质上是一个独立的编码器模型和一个独立的解码器模型的组合,这使我们产生了如何从独立模型检查点热启动编码器-解码器模型的问题。

有多种可能性来热启动编码器-解码器模型。可以:

  1. 从仅编码器模型检查点(如BERT)初始化编码器和解码器部分,
  2. 从仅编码器模型检查点(如BERT)初始化编码器部分,从仅解码器模型检查点(如GPT2)初始化解码器部分,
  3. 仅从仅编码器模型检查点初始化编码器部分,或
  4. 仅从仅解码器模型检查点初始化解码器部分。

接下来,我们将重点讨论可能性1和2。在理解了前两种可能性之后,可能性3和4就变得微不足道。

回顾编码器-解码器模型

首先,让我们快速回顾一下编码器-解码器架构。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第4张

编码器(显示为绿色)是一堆编码器块的堆叠。每个编码器块由一个双向自注意力层和两个前馈层组成1 {}^1 1 。解码器(显示为橙色)是一堆解码器块,后面跟着一个称为LM Head的密集层。每个解码器块由一个单向自注意力层、一个交叉注意力层和两个前馈层组成。

编码器将输入序列 X 1 : n \mathbf{X}_{1:n} X 1 : n ​ 映射到上下文编码序列 X ‾ 1 : n \mathbf{\overline{X}}_{1:n} X 1 : n ​ ,方式与BERT完全相同。然后,解码器将上下文编码序列 X ‾ 1 : n \mathbf{\overline{X}}_{1:n} X 1 : n ​ 和目标序列 Y 0 : m − 1 \mathbf{Y}_{0:m-1} Y 0 : m − 1 ​ 映射到逻辑向量 L 1 : m \mathbf{L}_{1:m} L 1 : m ​ 。与GPT2类似,然后使用这些逻辑向量来定义目标序列 Y 1 : m \mathbf{Y}_{1:m} Y 1 : m ​ 在给定输入序列 X 1 : n \mathbf{X}_{1:n} X 1 : n ​ 的条件下的分布,通过softmax操作。

用数学术语来说,首先,条件分布通过贝叶斯规则将 m − 1 m – 1 m − 1 个下一个词 y i \mathbf{y}_i y i ​ 的条件分布拆分。

p θ enc, dec ( Y 1 : m ∣ X 1 : n ) = p θ dec ( Y 1 : m ∣ X ‾ 1 : n ) = ∏ i = 1 m p θ dec ( y i ∣ Y 0 : i − 1 , X ‾ 1 : n ) ,其中 X ‾ 1 : n = f θ enc ( X 1 : n ) 。

每个“下一个词”的条件分布由以下逻辑向量的softmax定义。

p θ dec ( y i ∣ Y 0 : i − 1 , X ‾ 1 : n ) = Softmax ( l i ) 。

更多详细信息,请参考Encoder-Decoder笔记本。

使用BERT进行Encoder-Decoder的热启动

现在让我们来说明如何使用预训练的BERT模型来热启动Encoder-Decoder模型。BERT的预训练权重参数被用于初始化编码器和解码器的权重参数。为此,将BERT的架构与编码器的架构进行比较,并且编码器中与BERT存在的所有层将使用相应层的预训练权重参数进行初始化。编码器中不存在的所有层将简单地使用随机初始化的权重参数。

让我们进行可视化。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第5张

我们可以看到,编码器的架构与BERT的架构一一对应。所有编码器块的双向自注意层和两个前馈层的权重参数都使用相应BERT块的预训练权重参数进行初始化。这以第二个编码器块为例进行了说明(底部的红色框),其权重参数θ enc self-attn, 2和θ enc feed-forward, 2分别被设置为BERT的权重参数θ BERT feed-forward, 2和θ BERT self-attn, 2在初始化时。

在微调之前,编码器的行为与预训练的BERT模型完全相同。假设传递给编码器的输入序列x1,…,xn(绿色表示)等于传递给BERT的输入序列x1 BERT,…,xn BERT(灰色表示),这意味着相应的输出向量序列x̄1,…,x̄n(较深的绿色表示)和x̄1 BERT,…,x̄n BERT(较深的灰色表示)也必须相等。

接下来,让我们说明如何热启动解码器。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第6张

解码器的架构与BERT的架构有三个不同之处。

  1. 首先,解码器必须根据上下文化的编码序列X̄1:n进行条件化,通过交叉注意力层。因此,在每个BERT块的自注意层和两个前馈层之间添加了随机初始化的交叉注意力层。这在第二个块中通过+θ dec cross-attention, 2表示,并在右侧的下方红色框中以新增的全连接图形表示。这必然改变了每个修改后的BERT块的行为,以至于输入向量(例如y′0)现在会产生随机的输出向量y′′0(用红色边框标出)。

  2. 其次,BERT的双向自注意层必须改为单向自注意层,以符合自回归生成的要求。由于双向自注意层和单向自注意层都基于相同的键、查询和值投影权重,解码器的自注意层权重可以使用BERT的自注意层权重进行初始化。例如,解码器的单向自注意层的查询、键和值权重参数使用BERT的双向自注意层的相应权重参数进行初始化,即θ BERT self-attn, 2 = { W BERT, k self-attn, 2, W BERT, v self-attn, 2, W BERT, q self-attn, 2 } → θ dec self-attn, 2 = { W dec, k self-attn, 2, W dec, v self-attn, 2, W dec, q self-attn, 2 }。

  3. 第三,解码器输出一个逻辑向量序列L1:m,以定义条件概率分布pθ dec(Y1:n | X̄)。因此,在最后一个解码器块的顶部添加了一个LM Head层。LM Head层的权重参数通常对应于词嵌入Wemb的权重参数,因此不会进行随机初始化。这在顶部的初始化中以θ BERT word-emb → θ dec lm-head表示。

总之,当从预训练的BERT模型启动解码器时,只有交叉注意力层的权重被随机初始化。所有其他权重,包括自注意力层和语言模型头的权重,都使用BERT的预训练权重参数进行初始化。

在热启动了编码器-解码器模型之后,权重会在序列到序列的下游任务(例如摘要)上进行微调。

使用BERT和GPT2进行编码器-解码器的热启动

我们可以使用BERT的检查点来为编码器热启动,同时使用GPT2的检查点来为解码器热启动,而不是使用BERT检查点来为编码器和解码器都进行热启动。乍一看,只使用GPT2检查点来热启动解码器似乎更合适,因为它已经在因果语言建模上进行了训练,并使用了单向自注意力层。

让我们说明一下如何使用GPT2检查点来热启动解码器。

利用预训练的语言模型检查点来构建编码器-解码器模型 四海 第7张

我们可以看到解码器与GPT2更相似,而不是BERT。解码器的语言模型头的权重参数可以直接使用GPT2的语言模型头的权重参数进行初始化,例如 θ GPT2 lm-head → θ dec lm-head \theta_{\text{GPT2}}^{\text{lm-head}} \to \theta_{\text{dec}}^{\text{lm-head}}。此外,解码器和GPT2的块都使用单向自注意力,因此假设输入向量相同,解码器的自注意力层的输出向量与GPT2的输出向量是相等的,例如 y ′ 0 GPT2 = y ′ 0 \mathbf{y’}_0^{\text{GPT2}} = \mathbf{y’}_0。

与使用BERT初始化的解码器相比,使用GPT2初始化的解码器保留了自注意力层的因果连接图,如底部的红色框所示。

然而,使用GPT2初始化的解码器也必须根据 X ‾ 1 : n \mathbf{\overline{X}}_{1:n} X 1 : n ​ 对解码器进行条件操作。因此,在每个解码器块中添加了随机初始化的交叉注意力层的权重参数。例如,对于第二个编码器块,表示为 + θ dec cross-attention, 2 +\theta_{\text{dec}}^{\text{cross-attention, 2}} + θ dec cross-attention, 2 ​。

尽管GPT2与编码器-解码器模型的解码器部分更相似,但使用GPT2初始化的解码器在没有微调的情况下也会产生随机的逻辑向量 L 1 : m \mathbf{L}_{1:m} L 1 : m ​。这是由于每个解码器块中的交叉注意力层是随机初始化的。有趣的是,研究一下使用GPT2初始化的解码器是否能够产生更好的结果或更有效地进行微调。

编码器-解码器权重共享

在Raffel et al. (2020)中,作者表明,一个随机初始化的编码器-解码器模型,通过与解码器共享编码器的权重,从而将内存占用减少一半,其性能只略低于其“非共享”版本。与解码器共享编码器的权重意味着在相同位置找到的解码器的所有层共享相同的权重参数,即网络计算图中的相同节点。例如,第三个编码器块中自注意力层的查询、键和值投影矩阵,定义为 W Enc , k self-attn , 3 \mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, k} W Enc , k self-attn , 3 ​,W Enc , v self-attn , 3 \mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, v} W Enc , v self-attn , 3 ​,W Enc , q self-attn , 3 \mathbf{W}^{\text{self-attn}, 3}_{\text{Enc}, q} W Enc , q self-attn , 3 ​与第三个解码器块中自注意力层的相应查询、键和值投影矩阵是相同的。

W k self-attn , 3 = W enc , k self-attn , 3 ≡ W dec , k self-attn , 3 , \mathbf{W}^{\text{self-attn}, 3}_{k} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, k} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, k}, W k self-attn , 3 ​ = W enc , k self-attn , 3 ​ ≡ W dec , k self-attn , 3 ​ , W q self-attn , 3 = W enc , q self-attn , 3 ≡ W dec , q self-attn , 3 , \mathbf{W}^{\text{self-attn}, 3}_{q} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, q} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, q}, W q self-attn , 3 ​ = W enc , q self-attn , 3 ​ ≡ W dec , q self-attn , 3 ​ , W v self-attn , 3 = W enc , v self-attn , 3 ≡ W dec , v self-attn , 3 , \mathbf{W}^{\text{self-attn}, 3}_{v} = \mathbf{W}^{\text{self-attn}, 3}_{\text{enc}, v} \equiv \mathbf{W}^{\text{self-attn}, 3}_{\text{dec}, v}, W v self-attn , 3 ​ = W enc , v self-attn , 3 ​ ≡ W dec , v self-attn , 3 ​ ,

因此,关键投影权重 W k self-attn , 3 , W v self-attn , 3 , W q self-attn , 3 \mathbf{W}^{\text{self-attn}, 3}_{k}, \mathbf{W}^{\text{self-attn}, 3}_{v}, \mathbf{W}^{\text{self-attn}, 3}_{q} W k self-attn , 3 ​ , W v self-attn , 3 ​ , W q self-attn , 3 ​ 在每个反向传播过程中更新两次 – 一次是当梯度通过第三个解码器块反向传播时,一次是当梯度通过第三个编码器块反向传播时。

同样地,我们可以通过共享编码器权重来启动一个编码器-解码器模型。能够在编码器和解码器之间共享权重要求解码器的架构(不包括交叉注意力权重)与编码器的架构完全相同。因此,只有当编码器-解码器模型从单个仅编码器预训练检查点启动时,编码器-解码器权重共享才是相关的。

太好了!那就是关于启动编码器-解码器模型的理论。现在让我们看一些结果。


1 {}^1 1 为了不使方程和插图混乱,我们在这里排除了归一化层。2 {}^2 2 关于自注意力层如何工作的更多细节,请参考变压器编码器-解码器模型博文的这部分(以及相应的解码器部分)。

启动编码器-解码器模型(分析)

在本节中,我们将总结 Sascha Rothe、Shashi Narayan 和 Aliaksei Severyn 在《利用预训练检查点进行序列生成任务》中提出的关于启动编码器-解码器模型的发现。作者们将启动过的编码器-解码器模型与随机初始化的编码器-解码器模型在多个序列到序列任务上进行了比较,特别是摘要、翻译、句子分割和句子合并。

更准确地说,使用了公开可用的预训练检查点 BERTRoBERTaGPT2 的不同变体来启动编码器-解码器模型。例如,BERT 初始化的编码器与 BERT 初始化的解码器配对,形成 BERT2BERT 模型,或者 RoBERTa 初始化的编码器与 GPT2 初始化的解码器配对,形成 RoBERTa2GPT2 模型。此外,还研究了共享编码器和解码器权重的效果(如前一节所述),对 RoBERTa 进行了调查,即 RoBERTaShare ,以及对 BERT 进行了调查,即 BERTShare 。随机或部分随机初始化的编码器-解码器模型被用作基准,例如完全随机初始化的编码器-解码器模型,称为 Rnd2Rnd ,或者 BERT 初始化的解码器与随机初始化的编码器配对,定义为 Rnd2BERT

下表显示了所有调查的模型变体的完整列表,包括随机初始化权重的数量,即“random”,以及从相应的预训练检查点初始化的权重的数量,即“leveraged”。所有模型都基于一个12层的架构,具有768维的隐藏大小嵌入,对应于🤗Transformers模型库中的bert-base-casedbert-base-uncasedroberta-basegpt2检查点。

基于BERT2BERT架构的模型Rnd2Rnd包含221M个权重参数,全部是随机初始化的。另外两个“基于BERT的”基准模型Rnd2BERT和BERT2Rnd大约有一半的权重,即112M个参数,是随机初始化的。其他109M个权重参数分别从预训练的bert-base-uncased检查点中获取,用于编码器或解码器部分。模型BERT2BERT、BERT2GPT2和RoBERTa2GPT2的所有编码器权重参数都是从bert-base-uncasedroberta-base检查点中获取的,大部分解码器权重参数也是如此,分别从gpt2bert-base-uncased检查点中获取。其中26M个解码器权重参数,对应于12个交叉关注层,是随机初始化的。RoBERTa2GPT2和BERT2GPT2与Rnd2GPT2基准进行了比较。此外,值得注意的是,共享模型变体BERTShare和RoBERTaShare的参数数量明显较少,因为所有编码器权重参数与相应的解码器权重参数是共享的。

实验

上述模型在四个逐渐复杂的序列到序列任务上进行了训练和评估:句子级融合、句子级拆分、翻译和抽象摘要。下表显示了每个任务使用的数据集。

根据任务的不同,使用了稍微不同的训练方案。例如,根据数据集的大小和特定任务,训练步骤的数量范围从200K到500K,批量大小设置为128或256,输入长度范围从128到512,输出长度在32到128之间变化。然而,应当强调的是,在每个任务中,所有模型都使用相同的超参数进行训练和评估,以确保公平比较。有关任务特定的超参数设置的更多信息,请参阅论文中的实验部分。

现在我们将对每个任务的结果进行简要概述。

句子融合和拆分(DiscoFuse,WikiSplit)

句子融合是将多个句子合并为一个连贯的句子的任务。例如,以下两个句子:

作为一个阻截手,Zeitler的移动相对不错。Zeitler在空间中的接触点经常遇到困难。

应该用一个合适的连接词将它们连接起来,例如:

作为一个阻截手,Zeitler的移动相对不错。然而,他在空间中的接触点经常遇到困难。

可以看到,连接词“然而”提供了从第一句到第二句的连贯过渡。能够生成这样的连接词的模型可以说已经学会推断以上两个句子相互对比。

相反的任务称为句子拆分,它包括将单个复杂句子拆分成多个较简单的句子,这些句子共同保留相同的意思。句子拆分被认为是文本简化中的一个重要任务,参见Botha等人(2018)。

例如,下面的句子:

Street Rod是1989年为PC和Commodore 64发布的两个游戏系列中的第一个游戏

可以简化为

Street Rod是1989年为PC和Commodore 64发布的两个游戏它被看到,长句试图传达两个重要的信息。一个是该游戏是为PC发布的两个游戏中的第一个,而第二个是发布的年份。因此,句子拆分需要模型理解应该将句子的哪一部分分成两个句子,使任务比句子融合更加困难。

评估模型在句子融合和拆分任务上的常见指标是SARI(Wu等人(2016)),它广泛基于标签和模型输出的F1分数。

让我们看看模型在句子融合和拆分上的表现。

前两列显示了DiscoFuse评估数据上编码器-解码器模型的性能。第一列列出了在所有(100%)训练数据上训练的编码器-解码器模型的结果,而第二列显示了仅在10%训练数据上训练的模型的结果。我们观察到,启动热模型的性能显著优于随机初始化的基线模型Rnd2Rnd,Rnd2Bert和Rnd2GPT2。只在10%训练数据上训练的启动热的RoBERTa2GPT2模型与在100%训练数据上训练的Rnd2Rnd模型相当。有趣的是,Bert2Rnd基线的表现与完全启动热的Bert2Bert模型一样好,这表明启动热编码器部分比启动热解码器部分更有效。最佳结果由RoBERTa2GPT2获得,其次是RobertaShare。共享编码器和解码器的权重参数似乎会轻微增加模型的性能。

在更困难的句子拆分任务中,出现了类似的模式。启动热的编码器-解码器模型明显优于其编码器随机初始化的编码器-解码器模型,而具有共享权重参数的编码器-解码器模型的结果比具有解耦权重参数的模型更好。在句子拆分中,BertShare模型的性能最佳,紧随其后的是RobertaShare。

除了12层模型变体外,作者还训练和评估了一个24层的RobertaShare(large)模型,其性能显著优于所有12层模型。

机器翻译(WMT14)

接下来,作者在机器翻译(MT)中最常见的基准测试之一 – En → De和De → En WMT14数据集上评估了启动热的编码器-解码器模型。在本笔记本中,我们展示了在newstest2014 eval数据集上的结果。因为基准测试要求模型理解英语和德语词汇,所以BERT初始化的编码器-解码器模型是从多语言预训练检查点bert-base-multilingual-cased启动热的。由于没有公开可用的多语言RoBERTa检查点,因此RoBERTa初始化的编码器-解码器模型被排除在MT之外。GPT2初始化的模型是从先前实验中的gpt2预训练检查点初始化的。使用BLUE-4分数度量报告翻译结果1 {}^1 1。

同样,我们观察到通过启动热的编码器部分可以显著提升性能,在En → De和De → En任务中,BERT2Rnd和BERT2BERT的结果最佳。与Rnd2Rnd基线相比,GPT2初始化的模型在En → De上的性能甚至更差。考虑到gpt2检查点仅在英文文本上训练,BERT2GPT2和Rnd2GPT2模型在生成德语翻译时存在困难也就不足为奇了。这一假设得到了BERT2GPT2在De → En任务上的竞争结果(例如31.4 vs. 32.7)的支持,其中GPT2的词汇符合英文输出格式。与在句子融合和句子拆分上获得的结果相反,共享编码器和解码器的权重参数在MT中并不能提高性能。作者指出这可能的原因包括

  • 编码器-解码器模型容量在MT中是一个重要因素,以及
  • 编码器和解码器必须处理不同的语法和词汇

由于bert-base-multilingual-cased检查点训练了100多种语言,其词汇表对于En → De和De → En MT可能过大。因此,作者预训练了一个大型的BERT仅编码器检查点,用于英语和德语维基百科子集,并随后用它来启动热BERT2Rnd和BERTShare编码器-解码器模型。由于词汇表的改进,观察到了另一个显着的性能提升,BERT2Rnd(large, custom)明显优于所有其他模型。

摘要(CNN/Dailymail,BBC XSum,Gigaword)

最后,编码器-解码器模型在可能是最具挑战性的序列到序列任务-摘要中进行了评估。作者选择了三个具有不同特点的摘要数据集进行评估:Gigaword(标题生成)、BBC XSum(极端摘要)和CNN/Dailymayl(抽象摘要)。

Gigaword数据集包含句子级的抽象摘要,要求模型学习句子级的理解、抽象和重述。Gigaword中的一个典型数据样本,例如

“*委内瑞拉总统乌戈·查韦斯周四表示,他已下令调查涉嫌卷入现役和退役军官的政变阴谋。*”,

将具有相应的标题作为其标签,例如:

“查韦斯下令调查涉嫌政变阴谋”。

BBC XSum数据集包含更长的类似文章的文本输入,其标签大多是单句摘要。该数据集要求模型不仅学习文档级的推理,还要具备高水平的抽象性重述能力。BBC XSUM数据集的一些数据样本如下所示。

对于CNN/Dailmail数据集,文档的长度与BBC XSum数据集中的文档长度相似,需要将文档摘要为要点故事的重点。因此,标签通常包含多个句子。除了文档级的理解,CNN/Dailymail数据集还要求模型能够很好地复制最显著的信息。一些示例可以在这里查看。

模型使用Rouge指标进行评估,下面显示了Rouge-2分数。

好了,让我们来看看结果。

我们再次观察到,对编码器部分进行预热可以显著提高性能,而与随机初始化编码器的模型相比,这在文档级抽象任务(即CNN/Dailymail和BBC XSum)中尤为明显。这表明,对于需要高级抽象能力的任务,即使只需要句子级抽象能力的任务,预训练的编码器部分对于预测效果的改进更为明显。除了Gigaword之外,基于GPT2的编码器-解码器模型似乎不适用于摘要。

此外,共享的编码器-解码器模型是摘要任务中表现最佳的模型。RoBERTaShare和BERTShare是所有数据集上表现最佳的模型,而在BBC XSum数据集上,RoBERTaShare(large)的优势尤为显著,超过BERT2BERT和BERT2Rnd约3个Rouge-2分数,超过Rnd2Rnd超过8个Rouge-2分数。正如作者所提出的,“这可能是因为BBC摘要句子的分布与文档中的句子的分布相似,而对于Gigaword的标题和CNN/DailyMail的要点摘要来说,情况并非必然如此”。直观地讲,这意味着在BBC XSum中,编码器处理的输入句子在结构上与解码器处理的单句摘要非常相似,即长度相同,选择的词汇相似,语法相似。

结论

好了,让我们得出结论并尝试得出一些实用的提示。

  • 我们观察到,在所有任务中,与具有随机初始化编码器的编码器-解码器模型相比,预热的编码器部分显著提高了性能。另一方面,预热解码器似乎不太重要,BERT2BERT在大多数任务上与BERT2Rnd不相上下。一个直观的原因是,由于BERT或RoBERTa初始化的编码器部分没有任何权重参数是随机初始化的,因此编码器可以充分利用BERT或RoBERTa的预训练检查点所获得的知识。相比之下,预热的解码器始终有部分权重参数是随机初始化的,这可能使得解码器很难有效地利用用于初始化解码器的检查点所获得的知识。

  • 接下来,我们注意到共享编码器和解码器权重通常是有益的,特别是如果目标分布与输入分布相似(例如BBC XSum)。然而,对于目标数据分布与输入数据分布更显著不同,并且已知模型容量2 {}^2 2在其中发挥重要作用的数据集,例如WMT14,共享编码器-解码器权重似乎是不利的。

  • 最后,我们看到,预训练的“独立”检查点的词汇表与解决序列到序列任务所需的词汇表非常重要。例如,预热的BERT2GPT2编码器-解码器在En → De MT上表现不佳,因为GPT2是在英语上预训练的,而目标语言是德语。与BERT2BERT、BERTShared和RoBERTaShared相比,BERT2GPT2、Rnd2GPT2和RoBERTa2GPT2的整体性能较差,这表明共享词汇表更有效。此外,这也表明,将解码器部分初始化为预训练的GPT2检查点并不比使用预训练的BERT检查点初始化解码器更有效,尽管GPT2在架构上更类似于解码器。

对于上述每个任务,最有效的模型已被移植到 🤗Transformers 并可以在此处访问:

  • RoBERTaShared(大型)- Wikisplit:google/roberta2roberta_L-24_wikisplit。
  • RoBERTaShared(大型)- Discofuse:google/roberta2roberta_L-24_discofuse。
  • BERT2BERT(大型)- WMT en → de:google/bert2bert_L-24_wmt_en_de。
  • BERT2BERT(大型)- WMT de → en:google/bert2bert_L-24_wmt_de_en。
  • RoBERTaShared(大型)- CNN/Dailymail:google/roberta2roberta_L-24_cnn_daily_mail。
  • RoBERTaShared(大型)- BBC XSum:google/roberta2roberta_L-24_bbc。
  • RoBERTaShared(大型)- Gigaword:google/roberta2roberta_L-24_gigaword。

1 {}^1 1 为了获取 BLEU-4 分数,使用了来自 Tensorflow 官方 Transformer 实现的脚本 https://github.com/tensorflow/models/tree master/official/nlp/transformer。请注意,与 Vaswani 等人使用的 tensor2tensor/utils/ get_ende_bleu.sh 不同,该脚本不会分开名词复合词,但在注意到预处理后的训练集仅包含 ascii 引号后,utf-8 引号被规范化为 ascii 引号。

2 {}^2 2 模型容量是对模型在建模复杂模式方面的好坏进行非正式定义。有时它也被定义为模型从越来越多的数据中学习的能力。模型容量通常通过可训练参数的数量来衡量-参数越多,模型容量越高。

我们已经解释了暖启动编码器-解码器模型的理论,分析了多个数据集上的实证结果,并得出了实际结论。现在让我们通过一个完整的代码示例来展示如何在 CNN/Dailymail 摘要任务上暖启动并继续微调一个 BERT2BERT 模型。我们将利用 🤗datasets 和 🤗Transformers 库。

此外,以下列表提供了有关暖启动其他编码器-解码器模型组合的此笔记本和其他笔记本的精简版本。

  • 对于 CNN/Dailymail 上的 BERT2BERT(此笔记本的精简版本),单击此处。
  • 对于 BBC XSum 上的 RoBERTaShare,单击此处。
  • 对于 WMT14 En → De 上的 BERT2Rnd,单击此处。
  • 对于 DiscoFuse 上的 RoBERTa2GPT2,单击此处。

注意:此笔记本仅使用少量训练、验证和测试数据样本进行演示目的。要在完整的训练数据上对编码器-解码器模型进行微调,用户应根据注释中的高亮部分相应地更改训练和数据预处理参数。

数据预处理

在本节中,我们将展示如何为训练预处理数据。更重要的是,我们试图为读者提供决定如何预处理数据的过程的一些见解。

我们将需要安装 datasets 和 transformers。

!pip install datasets==1.0.2
!

输入数据似乎由短新闻文章组成。有趣的是,标签似乎是类似于项目符号的摘要。此时,我们可能需要查看其他几个示例,以更好地了解数据。

在这里还应该注意到文本是区分大小写的。这意味着如果我们想使用不区分大小写的模型,就必须小心。由于CNN/Dailymail是一个摘要数据集,模型将使用ROUGE指标进行评估。在🤗数据集的描述中检查ROUGE,详见这里,我们可以看到该指标是不区分大小写的,这意味着在评估过程中大写字母将被归一化为小写字母。因此,我们可以安全地使用不区分大小写的检查点,如bert-base-uncased

很棒!接下来,让我们对输入数据和标签的长度有一个概念。

由于模型计算长度时使用的是令牌长度,我们将使用bert-base-uncased分词器来计算文章和摘要的长度。

首先,我们加载分词器。

from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

接下来,我们使用.map()来计算文章和摘要的长度。由于我们知道bert-base-uncased可以处理的最大长度为512个令牌,我们还对输入样本中超过最大长度的百分比感兴趣。同样地,我们计算长度超过64和128的摘要的百分比。

我们可以定义.map()函数如下。

# 将文章和摘要长度映射为字典,同时记录样本是否超过512个令牌
def map_to_length(x):
  x["article_len"] = len(tokenizer(x["article"]).input_ids)
  x["article_longer_512"] = int(x["article_len"] > 512)
  x["summary_len"] = len(tokenizer(x["highlights"]).input_ids)
  x["summary_longer_64"] = int(x["summary_len"] > 64)
  x["summary_longer_128"] = int(x["summary_len"] > 128)
  return x

我们只需要查看前10000个样本就足够了。我们可以通过使用num_proc=4来加快映射的速度。

sample_size = 10000
data_stats = train_data.select(range(sample_size)).map(map_to_length, num_proc=4)

计算了前10000个样本的长度后,我们现在需要将它们的平均值计算在一起。为此,我们可以使用.map()函数,并设置batched=Truebatch_size=-1,以便在.map()函数中访问所有10000个样本。

def compute_and_print_stats(x):
  if len(x["article_len"]) == sample_size:
    print(
        "文章平均长度:{},超过512个令牌的样本百分比:{},摘要平均长度:{},超过64个令牌的摘要百分比:{},超过128个令牌的摘要百分比:{}".format(
            sum(x["article_len"]) / sample_size,
            sum(x["article_longer_512"]) / sample_size, 
            sum(x["summary_len"]) / sample_size,
            sum(x["summary_longer_64"]) / sample_size,
            sum(x["summary_longer_128"]) / sample_size,
        )
    )

output = data_stats.map(
  compute_and_print_stats, 
  batched=True,
  batch_size=-1,
)

    输出:
    -------
    文章平均长度:847.6216,超过512个令牌的样本百分比:0.7355,摘要平均长度:57.7742,超过64个令牌的摘要百分比:0.3185,超过128个令牌的摘要百分比:0.0

我们可以看到平均而言,一篇文章包含848个令牌,约有3/4的文章长度超过模型的max_length 512。摘要平均长度为57个令牌。超过30%的10000个样本摘要长度超过64个令牌,但没有一个摘要长度超过128个令牌。

bert-base-cased限制为512个令牌,这意味着我们可能需要从文章中删除可能重要的信息。由于大部分重要信息通常位于文章的开头,并且我们希望在计算上高效,我们决定在此笔记本中继续使用bert-base-cased,并将max_length设置为512。这个选择不是最优的,但已经证明在CNN/Dailymail上取得了不错的结果。另外,我们也可以使用长距离序列模型,比如Longformer作为编码器。

关于摘要长度,我们可以看到长度为128已经包括了所有的摘要标签。128在bert-base-cased的限制范围内,因此我们决定将生成限制在128内。

同样,我们将使用.map()函数 - 这次是将每个训练批次转换为模型输入的批次。

"article""highlights"被标记并准备为编码器的"input_ids"和解码器的"decoder_input_ids"

"labels"会自动向左移动一个位置,用于语言建模训练。

最后,非常重要的是要忽略填充标签的损失。在🤗Transformers中,可以通过将标签设置为-100来实现。太好了,让我们写下我们的映射函数。

encoder_max_length=512
decoder_max_length=128

def process_data_to_model_inputs(batch):
  # 对输入和标签进行标记化
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # 因为BERT自动将标签向左移动,所以标签与`decoder_input_ids`完全对应。
  # 我们必须确保忽略PAD标记
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

在这个笔记本中,我们只使用了一些训练样本对模型进行训练和评估,以示范,并将batch_size设置为4,以防止内存不足的问题。

以下代码将训练数据缩减为前32个样本。可以将此段代码注释掉或不运行以进行完整的训练。使用batch_size为16时获得了不错的结果。

train_data = train_data.select(range(32))

好的,让我们准备训练数据。

# batch_size = 16
batch_size=4

train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

查看处理后的训练数据集,我们可以看到列名articlehighlightsid已经被替换为EncoderDecoderModel所期望的参数。

train_data

输出:
-------
Dataset(features: {'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}, num_rows: 32)

到目前为止,数据是使用Python的List格式进行操作的。让我们将数据转换为PyTorch Tensors以在GPU上进行训练。

train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)

太棒了,训练数据的数据处理已经完成。类似地,我们可以对验证数据进行相同的处理。

首先,我们加载验证数据集的10%:

val_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

为了演示目的,然后将验证数据减少到只有8个样本,

val_data = val_data.select(range(8))

应用映射函数,

val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size, 
    remove_columns=["article", "highlights", "id"]
)

最后,验证数据也被转换为PyTorch张量。

val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"],
)

太棒了!现在我们可以开始预热EncoderDecoderModel

预热Encoder-Decoder模型

本节介绍了如何使用bert-base-cased检查点来预热Encoder-Decoder模型。

我们首先导入EncoderDecoderModel。有关EncoderDecoderModel类的更详细信息,请参阅文档。

from transformers import EncoderDecoderModel

与🤗Transformers中的其他模型类不同,EncoderDecoderModel类有两种加载预训练权重的方法:

  1. “标准”的.from_pretrained(...)方法派生自通用的PretrainedModel.from_pretrained(...)方法,因此与其他模型类的方法完全相同。该函数期望一个模型标识符,例如.from_pretrained("google/bert2bert_L-24_wmt_de_en"),并将一个.pt检查点文件加载到EncoderDecoderModel类中。

  2. 特殊的.from_encoder_decoder_pretrained(...)方法可以用于从两个模型标识符预热编码器-解码器模型 - 一个用于编码器,一个用于解码器。第一个模型标识符用于通过AutoModel.from_pretrained(...)(请参阅此处的文档)加载编码器,第二个模型标识符用于通过AutoModelForCausalLM(请参阅此处的文档)加载解码器。

好的,让我们预热我们的BERT2BERT模型。如前所述,我们将使用"bert-base-uncased"检查点来预热编码器和解码器。

bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

输出:
-------
"""在初始化BertLMHeadModel时,模型检查点bert-base-uncased的一些权重未被使用:['cls.seq_relationship.weight','cls.seq_relationship.bias']
    - 如果您正在从在另一个任务上训练或使用另一种架构的模型的检查点初始化BertLMHeadModel,则这是预期的(例如,从BertForPretraining模型初始化BertForSequenceClassification模型)。
    - 如果您正在从您希望完全相同的模型的检查点初始化BertLMHeadModel,则这是不预期的(例如,从BertForSequenceClassification模型初始化BertForSequenceClassification模型)。
    BertLMHeadModel的一些权重未从bert-base-uncased的模型检查点初始化,并且是新初始化的:['bert.encoder.layer.0.crossattention.self.query.weight','bert.encoder.layer.0.crossattention.self.query.bias','bert.encoder.layer.0.crossattention.self.key.weight','bert.encoder.layer.0.crossattention.self.key.bias','bert.encoder.layer.0.crossattention.self.value.weight','bert.encoder.layer.0.crossattention.self.value.bias','bert.encoder.layer.0.crossattention.output.dense.weight','bert.encoder.layer.0.crossattention.output.dense.bias','bert.encoder.layer.0.crossattention.output.LayerNorm.weight','bert.encoder.layer.0.crossattention.output.LayerNorm.bias','bert.encoder.layer.1.crossattention.self.query.weight','bert.encoder.layer.1.crossattention.self.query.bias','bert.encoder.layer.1.crossattention.self.key.weight','bert.encoder.layer.1.crossattention.self.key.bias','bert.encoder.layer.1.crossattention.self.value.weight','bert.encoder.layer.1.crossattention.self.value.bias','bert.encoder.layer.1.crossattention.output.dense.weight','bert.encoder.layer.1.crossattention.output.dense.bias','bert.encoder.layer.1.crossattention.output.LayerNorm.weight','bert.encoder.layer.1.crossattention.output.LayerNorm.bias','bert.encoder.layer.2.crossattention.self.query.weight','bert.encoder.layer.2.crossattention.self.query.bias','bert.encoder.layer.2.crossattention.self.key.weight','bert.encoder.layer.2.crossattention.self.key.bias','bert.encoder.layer.2.crossattention.self.value.weight','bert.encoder.layer.2.crossattention.self.value.bias','bert.encoder.layer.2.crossattention.output.dense.weight','bert.encoder.layer.2.crossattention.output.dense.bias','bert.encoder.layer.2.crossattention.output.LayerNorm.weight','bert.encoder.layer.2.crossattention.output.LayerNorm.bias','bert.encoder.layer.3.crossattention.self.query.weight','bert.encoder.layer.3.crossattention.self.query.bias','bert.encoder.layer.3.crossattention.self.key.weight','bert.encoder.layer.3.crossattention.self.key.bias','bert.encoder.layer.3.crossattention.self.value.weight','bert.encoder.layer.3.crossattention.self.value.bias','bert.encoder.layer.3.crossattention.output.dense.weight','bert.encoder.layer.3.crossattention.output.dense.bias','bert.encoder.layer.3.crossattention.output.LayerNorm.weight','bert.encoder.layer.3.crossattention.output.LayerNorm.bias','bert.encoder.layer.4.crossattention.self.query.weight','bert.encoder.layer.4.crossattention.self.query.bias','bert.encoder.layer.4.crossattention.self.key.weight','bert.encoder.layer.4.crossattention.self.key.bias','bert.encoder.layer.4.crossattention.self.value.weight','bert.encoder.layer.4.crossattention.self.value.bias','bert.encoder.layer.4.crossattention.output.dense.weight','bert.encoder.layer.4.crossattention.output.dense.bias','bert.encoder.layer.4.crossattention.output.LayerNorm.weight','bert.encoder.layer.4.crossattention.output.LayerNorm.bias','bert.encoder.layer.5.crossattention.self.query.weight','bert.encoder.layer.5.crossattention.self.query.bias','bert.encoder.layer.5.crossattention.self.key.weight','bert.encoder.layer.5.crossattention.self.key.bias','bert.encoder.layer.5.crossattention.self.value.weight','bert.encoder.layer.5.crossattention.self.value.bias','bert.encoder.layer.5.crossattention.output.dense.weight','bert.encoder.layer.5.crossattention.output.dense.bias','bert.encoder.layer.5.crossattention.output.LayerNorm.weight','bert.encoder.layer.5.crossattention.output.LayerNorm.bias','bert.encoder.layer.6.crossattention.self.query.weight','bert.encoder.layer.6.crossattention.self.query.bias','bert.encoder.layer.6.crossattention.self.key.weight','bert.encoder.layer.6.crossattention.self.key.bias','bert.encoder.layer.6.crossattention.self.value.weight','bert.encoder.layer.6.crossattention.self.value.bias','bert.encoder.layer.6.crossattention.output.dense.weight','bert.encoder.layer.6.crossattention.output.dense.bias','bert.encoder.layer.6.crossattention.output.LayerNorm.weight','bert.encoder.layer.6.crossattention.output.LayerNorm.bias','bert.encoder.layer.7.crossattention.self.query.weight','bert.encoder.layer.7.crossattention.self.query.bias','bert.encoder.layer.7.crossattention.self.key.weight','bert.encoder.layer.7.crossattention.self.key.bias','bert.encoder.layer.7.crossattention.self.value.weight','bert.encoder.layer.7.crossattention.self.value.bias','bert.encoder.layer.7.crossattention.output.dense.weight','bert

首先,我们应该仔细看一下这里的警告。我们可以看到,对应于"cls"层的两个权重没有被使用。这不应该是一个问题,因为对于序列到序列的任务,我们不需要BERT的CLS层。此外,我们注意到有很多权重是“新”或随机初始化的。仔细观察这些权重,我们会发现它们都对应于交叉注意力层,这正是我们在上面理论中所期望的。

让我们更仔细地看一下模型。

bert2bert

输出:
-------
    EncoderDecoderModel(
      (encoder): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            ),
                        ...
                        ,
            (11): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (decoder): BertLMHeadModel(
        (bert): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0): BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (crossattention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): BertIntermediate(
                  (dense): Linear(in_features=768, out_features=3072, bias=True)
                )
                (output): BertOutput(
                  (dense): Linear(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              ),
                            ...,
              (11): BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (crossattention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                  (output): BertSelfOutput(
                    (dense): Linear(in_features=768, out_features=768, bias=True)
                    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                    (dropout): Dropout(p=0.1, inplace=False)
                  )
                )
                (intermediate): BertIntermediate(
                  (dense): Linear(in_features=768, out_features=3072, bias=True)
                )
                (output): BertOutput(
                  (dense): Linear(in_features=3072, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
            )
          )
        )
        (cls): BertOnlyMLMHead(
          (predictions): BertLMPredictionHead(
            (transform): BertPredictionHeadTransform(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            )
            (decoder): Linear(in_features=768, out_features=30522, bias=True)
          )
        )
      )
    )

我们看到bert2bert.encoderBertModel的一个实例,bert2bert.decoderBertLMHeadModel的一个实例。然而,这两个实例现在被合并成一个torch.nn.Module,因此可以保存为单个.pt检查点文件。

让我们尝试使用标准的.save_pretrained(...)方法来保存它。

bert2bert.save_pretrained("bert2bert")

类似地,可以使用标准的.from_pretrained(...)方法重新加载模型。

bert2bert = EncoderDecoderModel.from_pretrained("bert2bert")

太棒了。让我们也保存配置。

bert2bert.config

输出:
-------
    EncoderDecoderConfig {
      "_name_or_path": "bert2bert",
      "architectures": [
        "EncoderDecoderModel"
      ],
      "decoder": {
        "_name_or_path": "bert-base-uncased",
        "add_cross_attention": true,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "do_sample": false,
        "early_stopping": false,
        "eos_token_id": null,
        "finetuning_task": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": true,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "layer_norm_eps": 1e-12,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 512,
        "min_length": 0,
        "model_type": "bert",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "pad_token_id": 0,
        "prefix": null,
        "pruned_heads": {},
        "repetition_penalty": 1.0,
        "return_dict": false,
        "sep_token_id": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "type_vocab_size": 2,
        "use_bfloat16": false,
        "use_cache": true,
        "vocab_size": 30522,
        "xla_device": null
      },
      "encoder": {
        "_name_or_path": "bert-base-uncased",
        "add_cross_attention": false,
        "architectures": [
          "BertForMaskedLM"
        ],
        "attention_probs_dropout_prob": 0.1,
        "bad_words_ids": null,
        "bos_token_id": null,
        "chunk_size_feed_forward": 0,
        "decoder_start_token_id": null,
        "do_sample": false,
        "early_stopping": false,
        "eos_token_id": null,
        "finetuning_task": null,
        "gradient_checkpointing": false,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "id2label": {
          "0": "LABEL_0",
          "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "is_decoder": false,
        "is_encoder_decoder": false,
        "label2id": {
          "LABEL_0": 0,
          "LABEL_1": 1
        },
        "layer_norm_eps": 1e-12,
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 512,
        "min_length": 0,
        "model_type": "bert",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 12,
        "num_beams": 1,
        "num_hidden_layers": 12,
        "num_return_sequences": 1,
        "output_attentions": false,
        "output_hidden_states": false,
        "pad_token_id": 0,
        "prefix": null,
        "pruned_heads": {},
        "repetition_penalty": 1.0,
        "return_dict": false,
        "sep_token_id": null,
        "task_specific_params": null,
        "temperature": 1.0,
        "tie_encoder_decoder": false,
        "tie_word_embeddings": true,
        "tokenizer_class": null,
        "top_k": 50,
        "top_p": 1.0,
        "torchscript": false,
        "type_vocab_size": 2,
        "use_bfloat16": false,
        "use_cache": true,
        "vocab_size": 30522,
        "xla_device": null
      },
      "is_encoder_decoder": true,
      "model_type": "encoder_decoder"
    }

配置同样由编码器配置和解码器配置组成,两者都是BertConfig的实例。然而,整体配置是EncoderDecoderConfig类型,因此保存为单个.json文件。

总之,应该记住,一旦实例化了一个EncoderDecoderModel对象,它就提供了与🤗Transformers中的任何其他编码器-解码器模型相同的功能,例如BART,T5,ProphetNet等。唯一的区别是EncoderDecoderModel提供了额外的from_encoder_decoder_pretrained(...)函数,允许从任何两个编码器和解码器检查点热启动模型类。

另外需要注意的是,如果想创建一个共享的编码器-解码器模型,可以将参数tie_encoder_decoder=True传递如下:

shared_bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased", tie_encoder_decoder=True)

对比一下,我们可以看到绑定模型的参数比预期少得多。

print(f"\n\n绑定模型参数数量: {shared_bert2bert.num_parameters()}, 非绑定模型参数数量: {bert2bert.num_parameters()}")

输出:
-------
绑定模型参数数量: 137298244, 非绑定模型参数数量: 247363386

在这个笔记本中,我们将训练一个非绑定的Bert2Bert模型,所以我们继续使用bert2bert而不是shared_bert2bert

# 释放内存
del shared_bert2bert

我们已经热启动了一个bert2bert模型,但是我们还没有定义用于beam搜索解码的所有相关参数。

让我们从设置特殊标记开始。由于bert-base-cased没有decoder_start_token_ideos_token_id,所以我们将使用它的cls_token_idsep_token_id。此外,我们应该在配置中定义一个pad_token_id并确保设置了正确的vocab_size

bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size

接下来,让我们定义与beam搜索解码相关的所有参数。由于<bart-large-cnn在CNN/Dailymail上产生了良好的结果,我们将直接复制其beam搜索解码参数。

有关每个参数的详细信息,请参阅此博文或文档。

bert2bert.config.max_length = 142
bert2bert.config.min_length = 56
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

好了,现在让我们开始微调热启动的BERT2BERT模型。

微调热启动的编码器-解码器模型

在本节中,我们将展示如何使用Seq2SeqTrainer对热启动的编码器-解码器模型进行微调。

让我们首先导入Seq2SeqTrainer和它的训练参数Seq2SeqTrainingArguments

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

此外,我们还需要一些Python包来使Seq2SeqTrainer正常工作。

!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu

Seq2SeqTrainer扩展了🤗Transformers中用于编码器-解码器模型的Trainer。简而言之,它允许在评估过程中使用generate(...)函数,这对于验证编码器-解码器模型在大多数序列到序列任务(例如摘要)上的性能是必要的。

想要了解有关Trainer的更多信息,可以阅读这个简短的教程。

让我们从配置Seq2SeqTrainingArguments开始。

参数predict_with_generate应该设置为True,这样Seq2SeqTrainer在验证数据上运行generate(...),并将生成的输出作为predictions传递给稍后我们将定义的compute_metric(...)函数。其他的参数都是从TrainingArguments派生的,可以在这里阅读。对于完整的训练过程,应根据需要更改这些参数。以下是一些很好的默认值。

想要了解有关Seq2SeqTrainer的更多信息,建议查看代码。

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True, 
    output_dir="./",
    logging_steps=2,
    save_steps=10,
    eval_steps=4,
    # logging_steps=1000,
    # save_steps=500,
    # eval_steps=7500,
    # warmup_steps=2000,
    # save_total_limit=3,
)

此外,我们需要定义一个函数来正确计算验证期间的ROUGE分数。由于我们激活了predict_with_generatecompute_metrics(...)函数期望使用generate(...)函数获取的predictions。像大多数摘要任务一样,CNN/Dailymail通常使用ROUGE分数进行评估。

让我们首先使用🤗datasets库加载ROUGE指标。

rouge = datasets.load_metric("rouge")

接下来,我们将定义compute_metrics(...)函数。rouge指标计算来自两个字符串列表的分数。因此,我们解码predictionslabels两者-确保-100被正确替换为pad_token_id,并通过设置skip_special_tokens=True来删除所有特殊字符。

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

太棒了,现在我们可以将所有参数传递给Seq2SeqTrainer并开始微调。执行以下单元格将需要大约10分钟。

在完整的CNN/Dailymail训练数据上微调BERT2BERT模型需要大约8小时的时间,在一台TITAN RTX GPU上。

# 实例化训练器
trainer = Seq2SeqTrainer(
    model=bert2bert,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()

太棒了,现在我们应该已经完全准备好微调一个预热的编码器-解码器模型了。为了检查我们微调的结果,让我们查看保存的检查点。

!ls

OUTPUT:
-------
    bert2bert      checkpoint-20  runs     seq2seq_trainer.py
    checkpoint-10  __pycache__    sample_data  seq2seq_training_args.py

最后,我们可以像往常一样通过EncoderDecoderModel.from_pretrained(...)方法加载检查点。

dummy_bert2bert = EncoderDecoderModel.from_pretrained("./checkpoint-20")

评估

在最后一步,我们可能想要对测试数据上的BERT2BERT模型进行评估。

首先,我们不再加载虚拟模型,而是加载在完整训练数据集上进行微调的BERT2BERT模型。同时,我们加载它的分词器,该分词器只是bert-base-cased的一个副本。

from transformers import BertTokenizer

bert2bert = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail").to("cuda")
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")

接下来,我们加载CNN/Dailymail测试数据的仅2%。对于完整的评估,应该使用100%的数据。

test_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="test[:2%]")

现在,我们可以再次利用🤗dataset的便利map()函数为每个测试样本生成摘要。

对于每个数据样本,我们:

  • 首先,对"article"进行分词,
  • 其次,生成输出的令牌ID,以及
  • 第三,解码输出的令牌ID,以获取我们预测的摘要。
def generate_summary(batch):
    # 在BERT最大长度512处截断
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_ids = inputs.input_ids.to("cuda")
    attention_mask = inputs.attention_mask.to("cuda")

    outputs = bert2bert.generate(input_ids, attention_mask=attention_mask)

    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred_summary"] = output_str

    return batch

让我们运行map函数以获取结果字典,其中存储了模型对每个样本的预测摘要。执行以下单元格可能需要大约10分钟 ☕。

batch_size = 16  # 对于完整的评估,改为64

results = test_data.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

最后,我们计算ROUGE分数。

rouge.compute(predictions=results["pred_summary"], references=results["highlights"], rouge_types=["rouge2"])["rouge2"].mid

输出:
-------
    Score(precision=0.10389454113300968, recall=0.1564771201053348, fmeasure=0.12175271663717585)

就是这样。我们展示了如何对BERT2BERT模型进行预热启动,并在CNN/Dailymail数据集上进行微调/评估。

完全训练好的BERT2BERT模型已上传到🤗模型中心,名称为patrickvonplaten/bert2bert_cnn_daily_mail。

该模型在完整评估数据上的ROUGE-2分数为18.22,甚至比论文中报告的还要好一点。

对于一些摘要示例,建议读者使用该模型的在线推理API,[这里](链接)。

非常感谢Google Research的Sascha Rothe、Shashi Narayan和Aliaksei Severyn,以及🤗Hugging Face的Victor Sanh、Sylvain Gugger和Thomas Wolf对此进行校对并给予非常宝贵的反馈。

Leave a Reply

Your email address will not be published. Required fields are marked *