Press "Enter" to skip to content

使用PyTorch进行高效图像分割:第4部分

基于Vision Transformer模型

在这个四部分的系列文章中,我们将使用PyTorch中的深度学习技术,从头开始逐步实现图像分割。本部分将重点介绍使用基于Vision Transformer的模型进行图像分割的实现。

与Naresh Singh共同编写

图1:使用Vision Transformer模型架构运行图像分割的结果。从上到下依次为输入图像、地面实况分割掩模和预测分割掩模。来源:作者

文章大纲

在本文中,我们将介绍Transformer架构,这个架构已经席卷了深度学习的世界。Transformer是一种多模态的架构,可以对不同的模态进行建模,例如语言、视觉和音频。

在本文中,我们将:

  1. 了解Transformer架构和涉及的关键概念
  2. 了解Vision Transformer架构
  3. 引入一个基于Vision Transformer模型的模型,从头开始编写,以便您可以欣赏所有构建块和移动部件
  4. 跟随输入张量进入该模型,并检查它如何改变形状
  5. 使用该模型在Oxford IIIT Pet数据集上执行图像分割
  6. 观察这个分割任务的结果
  7. 简要介绍SegFormer,这是一种用于语义分割的最先进的Vision Transformer

在整个文章中,我们将参考代码和结果,这些代码和结果来自于此笔记本电脑的模型训练。如果您希望重现结果,您需要一个GPU,以确保第一个笔记本电脑在合理的时间内完成运行。

本系列文章

本系列文章适合所有深度学习的读者。如果您想学习关于深度学习和视觉AI的实践,以及一些扎实的理论和实践经验,那么您来对地方了!这预计是一个四部分的系列,包括以下文章:

  1. 概念和思想
  2. 基于CNN的模型
  3. 深度可分离卷积
  4. 基于Vision Transformer的模型(本文)

让我们开始我们的Vision Transformer之旅,介绍Transformer架构的简介和直观理解。

Transformer架构

我们可以将Transformer架构看作是通信和计算的交错层组成。这个想法在图2中以图像的形式呈现。Transformer有N个处理单元(在图2中N为3),每个处理单元负责处理输入的1/N部分。为了使这些处理单元产生有意义的结果,每个处理单元都需要全局视图的输入。因此,系统会将有关每个处理单元中数据的信息反复传递给其他每个处理单元;使用从每个处理单元到每个其他处理单元的红色、绿色和蓝色箭头表示。然后进行一些基于这些信息的计算。经过足够多次的这个过程,模型能够产生所需的结果。

图2:Transformers中的交错通信和计算。该图显示了只有2个层的交错通信和计算。实际上,有许多这样的层。来源:作者

值得注意的是,大多数在线资源通常讨论Transformer的编码器和解码器,正如标题为“注意力全是你需要的”中所述。但是,在本文中,我们将仅描述Transformer的编码器部分。

让我们更详细地了解Transformers中构成通信和计算的部分。

Transformers中的通信:注意力

在Transformers中,通信是通过一种称为注意力层的层来实现的。在PyTorch中,这称为MultiHeadAttention。我们稍后会理解这个名称的原因。

文档中提到:

“允许模型联合关注来自不同表示子空间的信息,如论文中所述:Attention is all you need。”

注意机制消耗形状为(Batch,Length,Features)的输入张量x,并生成类似形状的张量y,以便根据张量正在关注的同一实例中的其他输入来更新每个输入的特征。因此,在长度为“特征”的每个张量的实例中,基于其他每个张量更新特征。这就是注意力机制的二次成本所在的地方。

图3:单词“it”相对于句子中其他单词的关注。我们可以看到“it”正在关注同一句子中的单词“animal”、“too”和“tire(d)”。来源:使用此colab生成。

在视觉变换器的上下文中,变换器的输入是一张图像。让我们将其假设为128 x 128(宽度,高度)的图像。我们将其分成多个大小为(16 x 16)的较小块。对于128 x 128的图像,我们得到64个块(长度),每行8个块,共8行块。

这64个大小为16 x 16像素的块中的每个块都被视为变换器模型的单独输入。不必深入细节,只需将此过程视为由64个不同的处理单元驱动,每个处理单元都在处理单个16×16图像块。

在每一轮中,每个处理单元中的注意机制负责查看其负责的图像块,并查询其余63个处理单元中的每一个,以询问它们可能与其自己的图像块有效处理有关的任何信息和有用信息。

通过关注进行的通信步骤后,是我们将要看到的计算。

变换器中的计算:多层感知器

变换器中的计算仅是一个多层感知器(MLP)单元。该单元由2个线性层组成,中间有一个GeLU非线性层。也可以考虑使用其他非线性。

该单元首先将输入投影到4倍大小,然后重新将其投影回1倍大小,这与输入大小相同。

在我们的笔记本中看到的代码中,此类被称为MultiLayerPerceptron。代码如下。

class MultiLayerPerceptron(nn.Sequential):    def __init__(self, embed_size, dropout):        super().__init__(            nn.Linear(embed_size, embed_size * 4),            nn.GELU(),            nn.Linear(embed_size * 4, embed_size),            nn.Dropout(p=dropout),        )    # end def# end class

现在我们了解了变换器体系结构的高级工作原理,让我们将注意力集中在视觉变换器上,因为我们将执行图像分割。

视觉变换器

视觉变换器首次由题为“图像价值16×16个字:用于规模图像识别的Transformers”的论文介绍。该论文讨论了作者如何将香草变换器体系结构应用于图像分类问题。这是通过将图像分成大小为16×16的块,并将每个块视为模型的一个输入令牌来完成的。变换器编码器模型被馈送这些输入令牌,并要求为输入图像预测一个类。

图4:来源:用于规模图像识别的Transformers。

在我们的情况下,我们对图像分割感兴趣。我们可以将其视为像素级分类任务,因为我们打算为每个像素预测一个目标类别。

我们对基础视觉变换器进行了一个小但重要的改变,将用于分类的MLP头替换为用于像素级分类的MLP头。我们在输出中只有一个线性层,这个线性层被所有的补丁共享,这些补丁的分割掩模由视觉变换器预测。这个共享的线性层为每个作为输入发送到模型的补丁预测一个分割掩模。

在视觉变换器的情况下,大小为16×16的补丁被认为相当于特定时间步长的单个输入令牌。

图5:用于图像分割的视觉变换器的端到端工作。使用此笔记本生成的图像。来源:作者。

建立视觉变换器中张量维度的直觉

在使用深度CNN时,我们主要使用的张量维度为(N,C H,W),其中字母表示以下内容:

  • N:批量大小
  • C:通道数
  • H:高度
  • W:宽度

您可以看到,这种格式针对二维图像处理,因为它具有非常特定于图像的特征。

另一方面,使用变换器的情况则变得更加通用和领域无关。我们将在下面看到的内容适用于视觉、文本、NLP、音频或其他可以将输入数据表示为序列的问题。值得注意的是,当张量通过我们的视觉变换器流动时,它们的表示中几乎没有视觉特定的偏差。

在使用变换器和注意力的情况下,我们希望张量具有以下形状:(B,T,C),其中字母表示以下内容:

  • B:批量大小(与CNN相同)
  • T:时间维度或序列长度。此维度有时也称为L。在视觉变换器的情况下,每个图像补丁对应于此维度。如果我们有16个图像补丁,则T维度的值将为16
  • C:通道或嵌入大小维度。此维度有时也称为E。在处理图像时,每个大小为3x16x16(通道、宽度、高度)的补丁通过补丁嵌入层映射到大小为C的嵌入。稍后我们将看到如何完成此操作。

让我们深入了解输入图像张量在预测分割掩模的过程中如何发生变化和处理。

视觉变换器中张量的旅程

在深度CNN中,张量的旅程看起来像这样(在UNet、SegNet或其他基于CNN的体系结构中)。

输入张量通常是形状为(1、3、128、128)的。这个张量经过一系列卷积和最大池化操作,其中它的空间维度被缩小,通道维度通常每次增加2倍。这称为特征编码器。在此之后,我们执行相反的操作,其中我们增加空间维度并减少通道维度。这称为特征解码器。在解码过程之后,我们得到形状为(1、64、128、128)的张量。然后使用1×1无偏差卷积将其投影到我们所需的输出通道数C中,形状为(1、C、128、128)。

图6:用于图像分割的深度CNN中张量形状的典型进展。来源:作者。

与视觉变换器相比,流程要复杂得多。让我们看一下下面的图像,然后试着理解张量如何在每一步中转换形状。

图7:用于图像分割的视觉变换器中张量形状的典型进展。来源:作者。

让我们更详细地看一下每个步骤,以及它如何更新通过视觉变换器流动的张量的形状。为了更好地理解这一点,让我们为我们的张量维度取具体值。

  1. 批量归一化:输入和输出张量的形状为(1,3,128,128)。形状不变,但值被归一化为零均值和单位方差。
  2. 图像转换为补丁:形状为(1,3,128,128)的输入张量被转换为16×16图像的堆叠补丁。输出张量的形状为(1,64,768)。
  3. 补丁嵌入:补丁嵌入层将768个输入通道映射到512个嵌入通道(针对此示例)。输出张量的形状为(1,64,512)。补丁嵌入层基本上只是PyTorch中的一个nn.Linear层。
  4. 位置嵌入:位置嵌入层没有输入张量,但有效地贡献了一个可学习的参数(在PyTorch中是可训练的张量),其形状与补丁嵌入相同。这是形状为(1,64,512)的。
  5. 加法:补丁和位置嵌入被逐个相加以产生我们的视觉变换器编码器的输入。该张量的形状为(1,64,512)。您会注意到,视觉变换器的主要工作部分,即编码器基本上使此张量形状保持不变。
  6. 变压器编码器:形状为(1,64,512)的输入张量通过多个变压器编码器块流动,每个块都有多个注意头(通信),然后是一个MLP层(计算)。张量形状保持不变,为(1,64,512)。
  7. 线性输出投影:如果我们假设我们想将每个图像分割为10个类,则需要使大小为16×16的每个补丁具有10个通道。输出投影的nn.Linear层现在将512个嵌入通道转换为16x16x10 = 2560个输出通道,这个张量看起来像(1,64,2560)。在上面的图表中,C’=10。理想情况下,这将是一个多层感知器,因为“ MLP是通用函数逼近器”,但由于这是一个教育性练习,我们使用一个单一线性层。
  8. 补丁到图像:此层将编码为(1,64,2560)张量的64个补丁转换回类似于分割掩模的东西。这可以是10个单通道图像,或在这种情况下是一个10通道图像,其中每个通道都是10个类别之一的分割掩模。输出张量的形状为(1,10,128,128)。

这就是全部——我们已成功使用视觉变换器对输入图像进行了分割!接下来,让我们看一下实验以及一些结果。

视觉变换器的应用

此笔记本包含此部分的所有代码。

就代码和类结构而言,它与上面的块图基本相似。此笔记本中的大多数概念都与本文中提及的概念具有1:1的对应关系。

与注意层相关的一些概念是我们模型的关键超参数。我们之前没有提及多头注意力的详细信息,因为我们提到它超出了本文的范围。如果您对变换器中的注意机制没有基本了解,我们强烈建议在继续之前阅读上面提到的参考资料。

我们为分割的视觉变换器使用了以下模型参数。

  1. PatchEmbedding层的768个嵌入维度
  2. 12个变压器编码器块
  3. 每个变压器编码器块中有8个注意头
  4. 多头注意力和MLP中的20%丢失

可以在VisionTransformerArgs Python数据类中看到此配置。

@dataclassclass VisionTransformerArgs:    """VisionTransformerForSegmentation的参数。"""    image_size: int = 128    patch_size: int = 16    in_channels: int = 3    out_channels: int = 3    embed_size: int = 768    num_blocks: int = 12    num_heads: int = 8    dropout: float = 0.2# end class

在模型训练和验证期间,使用了与之前类似的配置。具体配置如下:

  1. 对训练集应用随机水平翻转和颜色抖动数据增强以防止过拟合
  2. 将图像调整为128×128像素,进行非宽高比保持的调整操作
  3. 不对图像应用输入归一化,而是将批归一化层用作模型的第一层
  4. 使用Adam优化器进行50个epochs的训练,学习率为0.0004,使用StepLR调度程序,在每12个epochs时将学习率降低0.8倍
  5. 使用交叉熵损失函数将像素分类为宠物、背景或宠物边框

该模型具有86.28M参数,在50个训练epochs后实现了85.89%的验证准确度。这比深度CNN模型在20个训练epochs后达到的88.28%的准确度要低。这可能是由于需要通过实验验证的一些因素。

  1. 最后一个输出投影层是单个nn.Linear而不是多层感知器
  2. 16×16的补丁大小过大,无法捕捉更细致的细节
  3. 没有足够的训练epochs
  4. 训练数据不足——已知与深度CNN模型相比,transformer模型需要更多的数据进行有效训练
  5. 学习率太低

我们绘制了一个gif,展示了该模型如何学习预测验证集中21张图像的分割掩模。

Figure 8: A gif showing the progression of segmentation masks predicted by the vision transformer for image segmentation model. Source: Author(s).

我们在早期的训练epochs中注意到了一些有趣的现象。预测的分割掩模具有一些奇怪的阻塞伪影。唯一我们能想到的原因是,我们将图像分解为大小为16×16的补丁,在极少数的训练epochs之后,该模型没有学会任何有用的信息,仅了解到该16×16补丁通常被宠物或背景像素覆盖。

Figure 9: The blocking artifacts seen in the predicted segmentation masks when using the vision transformer for image segmentation. Source: Author(s).

现在我们已经看到了基本的视觉transformer的工作原理,让我们把注意力转向用于分割任务的最先进的视觉transformer。

SegFormer:具有transformer的语义分割

SegFormer体系结构是在2021年提出的。我们之前看到的transformer是SegFormer体系结构的简化版本。

Figure 10: The SegFormer architecture. Source: SegFormer paper (2021) .

最值得注意的是,SegFormer:

  1. 生成4组带有4×4、8×8、16×16和32×32大小补丁的图像集,而不是带有16×16大小补丁的单个补丁图像
  2. 使用4个transformer编码器块,而不是仅使用1个。这感觉像是一个模型合集
  3. 在自我注意力的前后阶段使用卷积
  4. 不使用位置嵌入
  5. 每个transformer块在空间分辨率H/4×W/4、H/8×W/8、H/16×W/16和H/32、W/32处处理图像
  6. 类似地,当空间尺寸减小时,通道数量增加。这感觉类似于深度CNN
  7. 在多个空间尺度上的预测被上采样,然后在解码器中合并在一起
  8. 一个MLP结合所有这些预测提供最终预测
  9. 最终预测在空间维度H/4,W/4而不是在H,W处。

结论

在本系列的第四部分中,我们介绍了Transformer架构和视觉Transformer。我们对视觉Transformer的工作原理有了直观的理解,并了解了视觉Transformer在通信和计算阶段涉及的基本构建块。我们看到了视觉Transformer采用的独特的基于补丁的方法,用于预测分割掩模,然后将预测组合在一起。

我们回顾了一项展示视觉Transformer的实验,并能够与深度CNN方法进行比较。虽然我们的视觉Transformer不是最先进的,但它能够取得相当不错的结果。我们提供了一瞥最先进的方法,例如SegFormer。

现在应该清楚,与基于深度CNN的方法相比,Transformer具有更多的移动部件和更复杂的结构。从原始的FLOPs角度来看,Transformer具有更高的效率潜力。在Transformer中,唯一真正计算复杂的层是nn.Linear。这在大多数架构上使用优化的矩阵乘法来实现。由于这种结构上的简单性,与基于深度CNN的方法相比,Transformer具有更易于优化和加速的潜力。

恭喜您走到了这一步!我们很高兴您喜欢阅读这篇有关PyTorch中高效图像分割的系列文章。如果您有问题或意见,请随时在评论区留言。

进一步阅读

本文不涉及注意力机制的细节。此外,有许多高质量的资源可供您参考,以了解注意力机制的详细信息。以下是我们强烈推荐的一些资源。

  1. Illustrated Transformer
  2. NanoGPT from scratch using PyTorch

我们将提供以下文章的链接,其中提供了有关视觉Transformer的更多详细信息。

  1. 在PyTorch中实现Vision Transformer(ViT):本文详细介绍了在PyTorch中实现用于图像分类的视觉Transformer。值得注意的是,他们的实现使用了einops,而我们避免使用它,因为这是一个面向教育的练习(我们建议学习和使用einops来提高代码可读性)。我们使用本机PyTorch运算符来排列和重新排列张量维度。此外,作者有时使用Conv2d而不是Linear层。我们想建立一个完全不使用卷积层的视觉Transformer的实现。
  2. Vision Transformer:AI Summer
  3. 在PyTorch中实现SegFormer
Leave a Reply

Your email address will not be published. Required fields are marked *