Press "Enter" to skip to content

NT-Xent(归一化温度缩放的交叉熵损失)损失的解释和在PyTorch中的实现

通过逐步解释操作和我们在PyTorch中的实现来直观解释NT-Xent损失

与Naresh Singh共同撰写。

NT-Xent损失公式。来源:Papers with code (CC-BY-SA)

介绍

自监督学习和对比学习的最新进展已经激发了机器学习(ML)领域的研究人员和从业者对这个领域的重新关注。

特别是,SimCLR论文提出了一种对视觉表征进行对比学习的简单框架,引起了自监督和对比学习领域的广泛关注。

这篇论文背后的核心思想非常简单——允许模型学习一对图像是否来自同一初始图像或者不同的初始图像。

图1:SimCLR背后的高层思想。来源:SimCLR论文

SimCLR方法将每个输入图像i编码为特征向量zi。有两种情况需要考虑:

  1. 正向对:使用不同的一组增强技术增强相同的图像,并比较得到的特征向量zizj。这些特征向量被强制相似的损失函数。
  2. 负向对:使用不同的一组增强技术增强不同的图像,并比较得到的特征向量zizk。这些特征向量被强制不相似的损失函数。

本文的其余部分将重点介绍和理解这个损失函数,以及使用PyTorch进行高效实现。

NT-Xent损失

在高层次上,对比学习模型接收2N个图像,这些图像来源于N个底层图像。每个底层图像使用随机一组图像增强技术来增强,从而产生2个增强图像。这是我们在单个训练批次中得到2N个图像的方式。

图2:对比学习的单个训练批次中的6个图像。当输入到对比学习模型时,每个图像下面的数字是该图像在输入批次中的索引。图片来源:牛津视觉几何组(CC-SA)

在接下来的章节中,我们将深入研究NT-Xent损失的以下方面。

  1. 温度对SoftMax和Sigmoid的影响
  2. NT-Xent损失的简单直观解释
  3. PyTorch中NT-Xent的逐步实现
  4. 多标签损失函数(NT-BXent)的需求动机
  5. PyTorch中NT-BXent的逐步实现

步骤2-5的所有代码都可以在此notebook中找到。步骤1的代码可以在此notebook中找到。

温度对SoftMax和Sigmoid的影响

为了理解本文中要研究的对比损失函数的所有移动部分,我们首先需要了解温度对 SoftMax 和 Sigmoid 激活函数的影响。

通常,对 SoftMax 或 Sigmoid 的输入进行温度缩放,以平滑或强调这些激活函数的输出。输入的 logits 在传入激活函数之前被温度除以。您可以在此 notebook 中找到此部分的所有代码。

SoftMax :对于 SoftMax,高温度降低了输出分布的方差,导致标签变软。低温度增加了输出分布的方差,并使最大值在其他值上突出。当输入张量为 [0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451] 时,请参见下面的图表以了解温度对 SoftMax 的影响。

Figure 3: Effect of temperature on SoftMax. Source: Author(s)

Sigmoid :对于 Sigmoid,高温度会导致输出分布向 0.0 拉伸,而低温度会将输入拉伸到更高的值,使输出更接近于 0.0 或 1.0,具体取决于输入的无符号大小。

Figure 4: Effect of temperature on Sigmoid. Source: Author(s)

现在我们了解了不同温度值对 SoftMax 和 Sigmoid 函数的影响,让我们看看这如何适用于我们对 NT-Xent 损失的理解。

解释 NT-Xent 损失

通过理解本损失名称中的单个术语,可以理解 NT-Xent 损失。

  1. 规范化:余弦相似性在范围[-1.0,+1.0]内产生规范化得分
  2. 温度缩放:所有对之间的余弦相似性在计算交叉熵损失之前按比例缩放
  3. 交叉熵损失:底层损失是多类(单标签)交叉熵损失

如上所述,我们假设对于大小为 2N 的批次,以下索引处的特征向量表示正对 (0, 1),(2, 3),(4, 5),(6, 7)……,其余组合表示负对。这是在解释 NT-Xent 损失与 SimCLR 相关时需要记住的重要因素。

现在我们了解了 NT-Xent 损失中术语在上下文中的含义,让我们看一下计算特征向量批次的 NT-Xent 损失所需的机械步骤。

  1. SimCLR 模型生成的每个 2N 向量的所有对余弦相似性得分都计算出来。这导致表示为 2N x 2N 矩阵的 (2N)² 相似性得分
  2. 忽略相同值(i,i)之间的比较结果(因为分布与自身完全相似,不可能允许模型学习任何有用信息)
  3. 每个值(余弦相似性)都由温度参数 𝜏(一个超参数)缩放
  4. 对上面的结果矩阵的每行应用交叉熵损失。下面的段落更详细地解释了这一点
  5. 通常,这些损失的平均值(每个批次元素一个损失)用于反向传播

在这里使用交叉熵损失的方式在语义上略有不同,与标准分类任务中使用的方式略有不同。在分类任务中,训练最终的“分类头”以为每个输入生成一个 one-hot-probability 向量,并计算该 one-hot-probability 向量上的交叉熵损失,因为我们实际上在计算两个分布之间的差异。这个视频很好地解释了交叉熵损失的概念。在 NT-Xent 损失中,可训练层与输出分布之间没有 1:1 的对应关系。相反,为每个输入计算一个特征向量,然后计算每对特征向量之间的余弦相似性。这里的技巧在于,由于每个图像与输入批次中的另一个图像相似(正对)(如果我们忽略特征向量与自身的相似性),因此我们可以将其视为类似于分类的设置,其中图像之间的相似度概率分布表示一个分类任务,其中一个图像将接近于 1.0,其余图像将接近于 0.0。

现在我们已经对 NT-Xent 损失有了一个扎实的整体理解,我们应该可以很好地在 PyTorch 中实现这些想法。让我们开始吧!

在 PyTorch 中实现 NT-Xent 损失

本节中的所有代码都可以在此笔记本中找到。

代码重用:在线上看到的许多 NT-Xent 损失实现从头开始实现了所有操作。此外,其中一些实现效率低下,更喜欢使用循环而不是 GPU 并行性来实现损失函数。相反,我们将使用不同的方法。我们将按照 PyTorch 已经提供的标准交叉熵损失来实现此损失。为此,我们需要将预测和真实标签转换为可以接受 cross_entropy 的格式。让我们看看如何在下面执行此操作。

预测张量:首先,我们需要创建一个 PyTorch 张量,它将表示我们的对比学习模型的输出。假设我们的批量大小为 8 (2N=8),我们的特征向量具有 2 个维度 (2 个值)。我们将称我们的输入变量为“x”。

x = torch.randn(8, 2)

余弦相似度:接下来,我们将计算此批次中每个特征向量之间的所有对余弦相似度,并将结果存储在名为“xcs”的变量中。如果下面的行看起来令人困惑,请在此页面上阅读详细信息。这是“规范化”步骤。

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

如上所述,我们需要忽略每个特征向量的自相似性得分,因为它对模型的学习没有贡献,并且稍后在计算交叉熵损失时会成为一个不必要的麻烦。为此,我们将定义一个名为“eye”的变量,它是一个矩阵,其中主对角线上的元素具有值 1.0,其余元素为 0.0。我们可以使用以下命令创建这样的矩阵。

eye = torch.eye(8)

现在让我们将其转换为布尔矩阵,以便我们可以使用此掩码矩阵索引到“xcs”变量中。

eye = eye.bool()

让我们将张量“xcs”克隆到名为“y”的张量中,以便我们稍后可以引用“xcs”张量。

y = xcs.clone()

现在,我们将把所有对余弦相似度矩阵的主对角线上的值设置为 -inf,这样当我们在每行上计算 softmax 时,此值将不会有任何贡献。

y[eye] = float("-inf")

通过温度参数缩放的张量“y”将是 PyTorch 中交叉熵损失 API 的一个输入 (预测)。接下来,我们需要计算我们需要馈送到交叉熵损失 API 的真实标签 (目标)。

真实标签 (目标张量):对于我们使用的示例 (2N=8),目标张量应如下所示。

tensor([1, 0, 3, 2, 5, 4, 7, 6])

这是因为张量“y”中的以下索引对包含正对。

(0, 1), (1, 0)

(2, 3), (3, 2)

(4, 5), (5, 4)

(6, 7), (7, 6)

为了解释上面的索引对,我们查看单个示例。对 (4, 5) 的配对意味着行 4 中的列 5 应设置为 1.0 (正对),这也是上面的张量所表达的。太好了!

为了创建上面的张量,我们可以使用以下 PyTorch 代码,它将真实标签存储在变量“target”中。

target = torch.arange(8)target[0::2] += 1target[1::2] -= 1

交叉熵损失:我们已经有了计算损失所需的所有要素!唯一剩下的就是在 PyTorch 中调用交叉熵 API。

loss = F.cross_entropy(y / temperature, target, reduction="mean")

变量“loss”现在包含了计算出的 NT-Xent 损失。下面让我们把所有的代码放在一个 Python 函数中。

def nt_xent_loss(x, temperature):  assert len(x.size()) == 2  # 余弦相似性  xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)  xcs[torch.eye(x.size(0)).bool()] = float("-inf")  # 真实标签  target = torch.arange(8)  target[0::2] += 1  target[1::2] -= 1  # 标准交叉熵损失  return F.cross_entropy(xcs / temperature, target, reduction="mean")

上述代码仅适用于在对比学习模型进行训练时,每个特征向量在批处理中恰好有一个正对(positive pair)的情况。接下来看一下如何处理对比学习任务中的多个正对(positive pairs)。

对比学习的多标签损失: NT-BXent

在 SimCLR 论文中,每个图像 i 在索引 j 处恰好有 1 个相似对。这使得交叉熵损失成为该任务的完美选择,因为它类似于多类别问题。但是,如果我们将同一图像的 M > 2 个增强版本馈送到对比学习模型的单个训练批次中,则每个批次将为图像 i 包含 M-1 个相似对。这个任务将类似于多标签问题。

显而易见的选择是将交叉熵损失替换为二元交叉熵损失。因此,它被称为规范化温度缩放的二元交叉熵损失(NT-BXent loss)。

下面的公式显示了元素 i 的损失 Li。公式中的 σ 表示 Sigmoid 函数。

Figure 5: Formulation for the NT-BXent loss. Image source: Author(s) of this article

为了避免类别不平衡问题,我们通过正负样本的倒数来加权正负对。用于反向传播的 mini-batch 的最终损失将是我们 mini-batch 中每个样本损失的平均值。

接下来,让我们关注一下在 PyTorch 中实现 NT-BXent 损失。

在 PyTorch 中实现 NT-BXent 损失

本节中的所有代码都可以在这个笔记本中找到。

代码复用:与我们实现 NT-Xent 损失的方法类似,我们将重用 PyTorch 提供的二元交叉熵(BCE)损失方法。我们设置的真实标签将类似于使用 BCE 损失的多标签分类问题。

预测张量:我们将使用与实现 NT-Xent 损失相同的(8,2)预测张量。

x = torch.randn(8, 2)

余弦相似度:由于输入张量 x 相同,所有对的余弦相似度张量 xcs 也将相同。请参阅此页面,了解下面这行代码的详细解释。

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)

为了确保位置 (i, i) 处的元素损失为 0,我们需要对 xcs 张量执行某些操作,以使其在应用 Sigmoid 后在每个索引 (i, i) 处包含值 1。由于我们将使用 BCE Loss,因此我们将在张量 xcs 中将每个特征向量的自相似度得分标记为 infinity。这是因为将 Sigmoid 函数应用于 xcs 张量会将无穷大转换为值 1,我们将设置我们的真实标签,以使得地面真实标签中的每个位置 (i, i) 具有值 1。

让我们创建一个掩码张量,其在主对角线上具有值Truexcs在主对角线上具有自相似性分数),并且其他地方为False

eye = torch.eye(8).bool()

让我们将张量“xcs”克隆到名为“y”的张量中,以便我们稍后可以引用“xcs”张量。

y = xcs.clone()

现在,我们将所有对余弦相似性矩阵的主对角线上的值设置为无穷大,以便在对每行进行sigmoid计算时,我们在这些位置得到1。

y[eye] = float("inf")

张量“y”乘以温度参数将是传递给PyTorch中BCE loss API的一组输入(预测)。接下来,我们需要计算我们需要传递给BCE loss API的ground-truth标签(目标)。

Ground Truth标签(目标张量):我们将期望用户向我们传递包含正例的所有(x,y)索引对。这与我们为NT-Xent损失所做的不同,因为正例是隐含的,而在这里,正例是显式的。

除了用户提供的位置外,我们还将所有对角线元素设置为正对,如上所述。我们将使用PyTorch张量索引API从这些位置拔出所有元素并将它们设置为1,而其余元素将初始化为0。

target = torch.zeros(8, 8)pos_indices = torch.tensor([  (0, 0), (0, 2), (0, 4),  (1, 4), (1, 6), (1, 1),  (2, 3),  (3, 7),  (4, 3),  (7, 6),])# 将主对角线的索引添加为正索引。# 这很有用,因为我们将在PyTorch中使用BCELoss,# 它将期望主对角线上的元素有一个值。pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)# 将目标向量中的值设置为1。target[pos_indices[:,0], pos_indices[:,1]] = 1

二元交叉熵(BCE)损失:与NT-Xent损失不同,我们不能简单地调用torch.nn.functional.binary_cross_entropy_function,因为我们想根据当前小批次中元素i具有多少正例和负例对来加权正负损失。

不过,第一步是计算逐元素BCE损失。

temperature = 0.1loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")

我们将创建一个正负对的二元掩码,然后创建2个张量,loss_pos和loss_neg,它们仅包含与正负对应的计算损失的那些元素。

target_pos = target.bool()target_neg = ~target_pos# loss_pos和loss_neg以下仅包含正负对应的那些元素的非零值。loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])

接下来,我们将分别对我们的小批次中每个元素i对应的正负对损失进行求和。

# loss_pos和loss_neg现在包含相对于i'th输入计算的正负对的总和。loss_pos = loss_pos.sum(dim=1)loss_neg = loss_neg.sum(dim=1)

要进行加权,我们需要跟踪与我们的小批次中每个元素i对应的正负对数。张量“num_pos”和“num_neg”将存储这些值。

# num_pos和num_neg以下包含相对于i'th输入计算的正负对的数量。在实际设置中,这个数字对于每个输入元素应该是相同的,但是我们在这里让它变化以获得最大的灵活性。num_pos = target.sum(dim=1)num_neg = target.size(0) - num_pos

我们拥有计算损失所需的所有元素!我们唯一需要做的就是通过正负对的数量对正负损失进行加权,然后平均每个小批量的损失。

def nt_bxent_loss(x, pos_indices, temperature):    assert len(x.size()) == 2    # Add indexes of the principal diagonal elements to pos_indices    pos_indices = torch.cat([        pos_indices,        torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2),    ], dim=0)        # Ground truth labels    target = torch.zeros(x.size(0), x.size(0))    target[pos_indices[:,0], pos_indices[:,1]] = 1.0    # Cosine similarity    xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)    # Set logit of diagonal element to "inf" signifying complete    # correlation. sigmoid(inf) = 1.0 so this will work out nicely    # when computing the Binary cross-entropy Loss.    xcs[torch.eye(x.size(0)).bool()] = float("inf")    # Standard binary cross-entropy loss. We use binary_cross_entropy() here and not    # binary_cross_entropy_with_logits() because of    # https://github.com/pytorch/pytorch/issues/102894    # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values    # to result in a NaN result.    loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")        target_pos = target.bool()    target_neg = ~target_pos        loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])    loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])    loss_pos = loss_pos.sum(dim=1)    loss_neg = loss_neg.sum(dim=1)    num_pos = target.sum(dim=1)    num_neg = x.size(0) - num_pos    return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()pos_indices = torch.tensor([    (0, 0), (0, 2), (0, 4),    (1, 4), (1, 6), (1, 1),    (2, 3),    (3, 7),    (4, 3),    (7, 6),])for t in (0.01, 0.1, 1.0, 10.0, 20.0):    print(f"温度: {t:5.2f}, 损失: {nt_bxent_loss(x, pos_indices, temperature=t)}")

打印输出结果如下。

温度: 0.01, 损失: 62.898780822753906

温度: 0.10, 损失: 4.851151943206787

温度: 1.00, 损失: 1.0727109909057617

温度: 10.00, 损失: 0.9827173948287964

温度: 20.00, 损失: 0.982099175453186

结论

自监督学习是深度学习中即将出现的领域,它允许我们在无标签数据上训练模型。这种技术使我们能够绕过大规模标记数据的要求。

在本文中,我们学习了对比学习的损失函数。第一个损失函数命名为NT-Xent loss,用于在小批量中每个输入的单个正对上进行学习。我们介绍了NT-BXent loss,它用于在小批量中每个输入的多个(>1)正对上进行学习。我们直观地解释了它们,借鉴了我们对交叉熵损失和二进制交叉熵损失的了解。最后,我们高效地在PyTorch中实现了它们。

Leave a Reply

Your email address will not be published. Required fields are marked *