通过逐步解释操作和我们在PyTorch中的实现来直观解释NT-Xent损失
与Naresh Singh共同撰写。
介绍
自监督学习和对比学习的最新进展已经激发了机器学习(ML)领域的研究人员和从业者对这个领域的重新关注。
特别是,SimCLR论文提出了一种对视觉表征进行对比学习的简单框架,引起了自监督和对比学习领域的广泛关注。
这篇论文背后的核心思想非常简单——允许模型学习一对图像是否来自同一初始图像或者不同的初始图像。
SimCLR方法将每个输入图像i编码为特征向量zi。有两种情况需要考虑:
- 正向对:使用不同的一组增强技术增强相同的图像,并比较得到的特征向量zi和zj。这些特征向量被强制相似的损失函数。
- 负向对:使用不同的一组增强技术增强不同的图像,并比较得到的特征向量zi和zk。这些特征向量被强制不相似的损失函数。
本文的其余部分将重点介绍和理解这个损失函数,以及使用PyTorch进行高效实现。
NT-Xent损失
在高层次上,对比学习模型接收2N个图像,这些图像来源于N个底层图像。每个底层图像使用随机一组图像增强技术来增强,从而产生2个增强图像。这是我们在单个训练批次中得到2N个图像的方式。
在接下来的章节中,我们将深入研究NT-Xent损失的以下方面。
- 温度对SoftMax和Sigmoid的影响
- NT-Xent损失的简单直观解释
- PyTorch中NT-Xent的逐步实现
- 多标签损失函数(NT-BXent)的需求动机
- PyTorch中NT-BXent的逐步实现
步骤2-5的所有代码都可以在此notebook中找到。步骤1的代码可以在此notebook中找到。
温度对SoftMax和Sigmoid的影响
为了理解本文中要研究的对比损失函数的所有移动部分,我们首先需要了解温度对 SoftMax 和 Sigmoid 激活函数的影响。
通常,对 SoftMax 或 Sigmoid 的输入进行温度缩放,以平滑或强调这些激活函数的输出。输入的 logits 在传入激活函数之前被温度除以。您可以在此 notebook 中找到此部分的所有代码。
SoftMax :对于 SoftMax,高温度降低了输出分布的方差,导致标签变软。低温度增加了输出分布的方差,并使最大值在其他值上突出。当输入张量为 [0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451] 时,请参见下面的图表以了解温度对 SoftMax 的影响。
Sigmoid :对于 Sigmoid,高温度会导致输出分布向 0.0 拉伸,而低温度会将输入拉伸到更高的值,使输出更接近于 0.0 或 1.0,具体取决于输入的无符号大小。
现在我们了解了不同温度值对 SoftMax 和 Sigmoid 函数的影响,让我们看看这如何适用于我们对 NT-Xent 损失的理解。
解释 NT-Xent 损失
通过理解本损失名称中的单个术语,可以理解 NT-Xent 损失。
- 规范化:余弦相似性在范围[-1.0,+1.0]内产生规范化得分
- 温度缩放:所有对之间的余弦相似性在计算交叉熵损失之前按比例缩放
- 交叉熵损失:底层损失是多类(单标签)交叉熵损失
如上所述,我们假设对于大小为 2N 的批次,以下索引处的特征向量表示正对 (0, 1),(2, 3),(4, 5),(6, 7)……,其余组合表示负对。这是在解释 NT-Xent 损失与 SimCLR 相关时需要记住的重要因素。
现在我们了解了 NT-Xent 损失中术语在上下文中的含义,让我们看一下计算特征向量批次的 NT-Xent 损失所需的机械步骤。
- SimCLR 模型生成的每个 2N 向量的所有对余弦相似性得分都计算出来。这导致表示为 2N x 2N 矩阵的 (2N)² 相似性得分
- 忽略相同值(i,i)之间的比较结果(因为分布与自身完全相似,不可能允许模型学习任何有用信息)
- 每个值(余弦相似性)都由温度参数 𝜏(一个超参数)缩放
- 对上面的结果矩阵的每行应用交叉熵损失。下面的段落更详细地解释了这一点
- 通常,这些损失的平均值(每个批次元素一个损失)用于反向传播
在这里使用交叉熵损失的方式在语义上略有不同,与标准分类任务中使用的方式略有不同。在分类任务中,训练最终的“分类头”以为每个输入生成一个 one-hot-probability 向量,并计算该 one-hot-probability 向量上的交叉熵损失,因为我们实际上在计算两个分布之间的差异。这个视频很好地解释了交叉熵损失的概念。在 NT-Xent 损失中,可训练层与输出分布之间没有 1:1 的对应关系。相反,为每个输入计算一个特征向量,然后计算每对特征向量之间的余弦相似性。这里的技巧在于,由于每个图像与输入批次中的另一个图像相似(正对)(如果我们忽略特征向量与自身的相似性),因此我们可以将其视为类似于分类的设置,其中图像之间的相似度概率分布表示一个分类任务,其中一个图像将接近于 1.0,其余图像将接近于 0.0。
现在我们已经对 NT-Xent 损失有了一个扎实的整体理解,我们应该可以很好地在 PyTorch 中实现这些想法。让我们开始吧!
在 PyTorch 中实现 NT-Xent 损失
本节中的所有代码都可以在此笔记本中找到。
代码重用:在线上看到的许多 NT-Xent 损失实现从头开始实现了所有操作。此外,其中一些实现效率低下,更喜欢使用循环而不是 GPU 并行性来实现损失函数。相反,我们将使用不同的方法。我们将按照 PyTorch 已经提供的标准交叉熵损失来实现此损失。为此,我们需要将预测和真实标签转换为可以接受 cross_entropy 的格式。让我们看看如何在下面执行此操作。
预测张量:首先,我们需要创建一个 PyTorch 张量,它将表示我们的对比学习模型的输出。假设我们的批量大小为 8 (2N=8),我们的特征向量具有 2 个维度 (2 个值)。我们将称我们的输入变量为“x”。
x = torch.randn(8, 2)
余弦相似度:接下来,我们将计算此批次中每个特征向量之间的所有对余弦相似度,并将结果存储在名为“xcs”的变量中。如果下面的行看起来令人困惑,请在此页面上阅读详细信息。这是“规范化”步骤。
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
如上所述,我们需要忽略每个特征向量的自相似性得分,因为它对模型的学习没有贡献,并且稍后在计算交叉熵损失时会成为一个不必要的麻烦。为此,我们将定义一个名为“eye”的变量,它是一个矩阵,其中主对角线上的元素具有值 1.0,其余元素为 0.0。我们可以使用以下命令创建这样的矩阵。
eye = torch.eye(8)
现在让我们将其转换为布尔矩阵,以便我们可以使用此掩码矩阵索引到“xcs”变量中。
eye = eye.bool()
让我们将张量“xcs”克隆到名为“y”的张量中,以便我们稍后可以引用“xcs”张量。
y = xcs.clone()
现在,我们将把所有对余弦相似度矩阵的主对角线上的值设置为 -inf,这样当我们在每行上计算 softmax 时,此值将不会有任何贡献。
y[eye] = float("-inf")
通过温度参数缩放的张量“y”将是 PyTorch 中交叉熵损失 API 的一个输入 (预测)。接下来,我们需要计算我们需要馈送到交叉熵损失 API 的真实标签 (目标)。
真实标签 (目标张量):对于我们使用的示例 (2N=8),目标张量应如下所示。
tensor([1, 0, 3, 2, 5, 4, 7, 6])
这是因为张量“y”中的以下索引对包含正对。
(0, 1), (1, 0)
(2, 3), (3, 2)
(4, 5), (5, 4)
(6, 7), (7, 6)
为了解释上面的索引对,我们查看单个示例。对 (4, 5) 的配对意味着行 4 中的列 5 应设置为 1.0 (正对),这也是上面的张量所表达的。太好了!
为了创建上面的张量,我们可以使用以下 PyTorch 代码,它将真实标签存储在变量“target”中。
target = torch.arange(8)target[0::2] += 1target[1::2] -= 1
交叉熵损失:我们已经有了计算损失所需的所有要素!唯一剩下的就是在 PyTorch 中调用交叉熵 API。
loss = F.cross_entropy(y / temperature, target, reduction="mean")
变量“loss”现在包含了计算出的 NT-Xent 损失。下面让我们把所有的代码放在一个 Python 函数中。
def nt_xent_loss(x, temperature): assert len(x.size()) == 2 # 余弦相似性 xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1) xcs[torch.eye(x.size(0)).bool()] = float("-inf") # 真实标签 target = torch.arange(8) target[0::2] += 1 target[1::2] -= 1 # 标准交叉熵损失 return F.cross_entropy(xcs / temperature, target, reduction="mean")
上述代码仅适用于在对比学习模型进行训练时,每个特征向量在批处理中恰好有一个正对(positive pair)的情况。接下来看一下如何处理对比学习任务中的多个正对(positive pairs)。
对比学习的多标签损失: NT-BXent
在 SimCLR 论文中,每个图像 i 在索引 j 处恰好有 1 个相似对。这使得交叉熵损失成为该任务的完美选择,因为它类似于多类别问题。但是,如果我们将同一图像的 M > 2 个增强版本馈送到对比学习模型的单个训练批次中,则每个批次将为图像 i 包含 M-1 个相似对。这个任务将类似于多标签问题。
显而易见的选择是将交叉熵损失替换为二元交叉熵损失。因此,它被称为规范化温度缩放的二元交叉熵损失(NT-BXent loss)。
下面的公式显示了元素 i 的损失 Li。公式中的 σ 表示 Sigmoid 函数。
为了避免类别不平衡问题,我们通过正负样本的倒数来加权正负对。用于反向传播的 mini-batch 的最终损失将是我们 mini-batch 中每个样本损失的平均值。
接下来,让我们关注一下在 PyTorch 中实现 NT-BXent 损失。
在 PyTorch 中实现 NT-BXent 损失
本节中的所有代码都可以在这个笔记本中找到。
代码复用:与我们实现 NT-Xent 损失的方法类似,我们将重用 PyTorch 提供的二元交叉熵(BCE)损失方法。我们设置的真实标签将类似于使用 BCE 损失的多标签分类问题。
预测张量:我们将使用与实现 NT-Xent 损失相同的(8,2)预测张量。
x = torch.randn(8, 2)
余弦相似度:由于输入张量 x 相同,所有对的余弦相似度张量 xcs 也将相同。请参阅此页面,了解下面这行代码的详细解释。
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
为了确保位置 (i, i) 处的元素损失为 0,我们需要对 xcs 张量执行某些操作,以使其在应用 Sigmoid 后在每个索引 (i, i) 处包含值 1。由于我们将使用 BCE Loss,因此我们将在张量 xcs 中将每个特征向量的自相似度得分标记为 infinity。这是因为将 Sigmoid 函数应用于 xcs 张量会将无穷大转换为值 1,我们将设置我们的真实标签,以使得地面真实标签中的每个位置 (i, i) 具有值 1。
让我们创建一个掩码张量,其在主对角线上具有值True(xcs在主对角线上具有自相似性分数),并且其他地方为False。
eye = torch.eye(8).bool()
让我们将张量“xcs”克隆到名为“y”的张量中,以便我们稍后可以引用“xcs”张量。
y = xcs.clone()
现在,我们将所有对余弦相似性矩阵的主对角线上的值设置为无穷大,以便在对每行进行sigmoid计算时,我们在这些位置得到1。
y[eye] = float("inf")
张量“y”乘以温度参数将是传递给PyTorch中BCE loss API的一组输入(预测)。接下来,我们需要计算我们需要传递给BCE loss API的ground-truth标签(目标)。
Ground Truth标签(目标张量):我们将期望用户向我们传递包含正例的所有(x,y)索引对。这与我们为NT-Xent损失所做的不同,因为正例是隐含的,而在这里,正例是显式的。
除了用户提供的位置外,我们还将所有对角线元素设置为正对,如上所述。我们将使用PyTorch张量索引API从这些位置拔出所有元素并将它们设置为1,而其余元素将初始化为0。
target = torch.zeros(8, 8)pos_indices = torch.tensor([ (0, 0), (0, 2), (0, 4), (1, 4), (1, 6), (1, 1), (2, 3), (3, 7), (4, 3), (7, 6),])# 将主对角线的索引添加为正索引。# 这很有用,因为我们将在PyTorch中使用BCELoss,# 它将期望主对角线上的元素有一个值。pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)# 将目标向量中的值设置为1。target[pos_indices[:,0], pos_indices[:,1]] = 1
二元交叉熵(BCE)损失:与NT-Xent损失不同,我们不能简单地调用torch.nn.functional.binary_cross_entropy_function,因为我们想根据当前小批次中元素i具有多少正例和负例对来加权正负损失。
不过,第一步是计算逐元素BCE损失。
temperature = 0.1loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")
我们将创建一个正负对的二元掩码,然后创建2个张量,loss_pos和loss_neg,它们仅包含与正负对应的计算损失的那些元素。
target_pos = target.bool()target_neg = ~target_pos# loss_pos和loss_neg以下仅包含正负对应的那些元素的非零值。loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
接下来,我们将分别对我们的小批次中每个元素i对应的正负对损失进行求和。
# loss_pos和loss_neg现在包含相对于i'th输入计算的正负对的总和。loss_pos = loss_pos.sum(dim=1)loss_neg = loss_neg.sum(dim=1)
要进行加权,我们需要跟踪与我们的小批次中每个元素i对应的正负对数。张量“num_pos”和“num_neg”将存储这些值。
# num_pos和num_neg以下包含相对于i'th输入计算的正负对的数量。在实际设置中,这个数字对于每个输入元素应该是相同的,但是我们在这里让它变化以获得最大的灵活性。num_pos = target.sum(dim=1)num_neg = target.size(0) - num_pos
我们拥有计算损失所需的所有元素!我们唯一需要做的就是通过正负对的数量对正负损失进行加权,然后平均每个小批量的损失。
def nt_bxent_loss(x, pos_indices, temperature): assert len(x.size()) == 2 # Add indexes of the principal diagonal elements to pos_indices pos_indices = torch.cat([ pos_indices, torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2), ], dim=0) # Ground truth labels target = torch.zeros(x.size(0), x.size(0)) target[pos_indices[:,0], pos_indices[:,1]] = 1.0 # Cosine similarity xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1) # Set logit of diagonal element to "inf" signifying complete # correlation. sigmoid(inf) = 1.0 so this will work out nicely # when computing the Binary cross-entropy Loss. xcs[torch.eye(x.size(0)).bool()] = float("inf") # Standard binary cross-entropy loss. We use binary_cross_entropy() here and not # binary_cross_entropy_with_logits() because of # https://github.com/pytorch/pytorch/issues/102894 # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values # to result in a NaN result. loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none") target_pos = target.bool() target_neg = ~target_pos loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos]) loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg]) loss_pos = loss_pos.sum(dim=1) loss_neg = loss_neg.sum(dim=1) num_pos = target.sum(dim=1) num_neg = x.size(0) - num_pos return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()pos_indices = torch.tensor([ (0, 0), (0, 2), (0, 4), (1, 4), (1, 6), (1, 1), (2, 3), (3, 7), (4, 3), (7, 6),])for t in (0.01, 0.1, 1.0, 10.0, 20.0): print(f"温度: {t:5.2f}, 损失: {nt_bxent_loss(x, pos_indices, temperature=t)}")
打印输出结果如下。
温度: 0.01, 损失: 62.898780822753906
温度: 0.10, 损失: 4.851151943206787
温度: 1.00, 损失: 1.0727109909057617
温度: 10.00, 损失: 0.9827173948287964
温度: 20.00, 损失: 0.982099175453186
结论
自监督学习是深度学习中即将出现的领域,它允许我们在无标签数据上训练模型。这种技术使我们能够绕过大规模标记数据的要求。
在本文中,我们学习了对比学习的损失函数。第一个损失函数命名为NT-Xent loss,用于在小批量中每个输入的单个正对上进行学习。我们介绍了NT-BXent loss,它用于在小批量中每个输入的多个(>1)正对上进行学习。我们直观地解释了它们,借鉴了我们对交叉熵损失和二进制交叉熵损失的了解。最后,我们高效地在PyTorch中实现了它们。