Press "Enter" to skip to content

如何使用对抗数据动态训练你的模型

你将在这里学到什么
  • 💡动态对抗数据收集的基本思想以及其重要性。
  • ⚒如何动态收集对抗数据并在其上训练模型 – 以MNIST手写数字识别任务为例。

动态对抗数据收集(DADC)

静态基准在评估模型性能时被广泛使用,但存在许多问题:它们容易饱和、存在偏见或漏洞,而且常常导致研究人员追求指标的增长,而不是构建可信赖的模型,可以被人类使用1。

动态对抗数据收集(DADC)作为一种缓解静态基准问题的方法具有很大的潜力。在DADC中,人类创造了一些例子来欺骗最先进的(SOTA)模型。这个过程有两个好处:

  1. 它允许用户评估他们的模型的鲁棒性如何;
  2. 它产生的数据可以用于进一步训练更强大的模型。

通过在对抗性收集的数据上欺骗并训练模型,不断重复这个过程,可以得到与人类一致的更强大的模型1。

使用对抗数据动态训练模型

在这里,我将指导您从用户动态收集对抗数据并在其上训练模型 – 使用MNIST手写数字识别任务。

在MNIST手写数字识别任务中,模型被训练以在给定手写数字(见下图中的示例)的28x28灰度图像输入下预测数字。数字的范围是从0到9。

如何使用对抗数据动态训练你的模型 四海 第1张

图片来源:mnist | Tensorflow数据集

这个任务被广泛认为是计算机视觉的入门,很容易训练出在标准(静态)基准测试集上达到高准确率的模型。然而,研究表明,这些SOTA模型在人类书写数字时(并将其作为输入提供给模型)仍然很难预测出正确的数字:研究人员认为,这主要是因为静态测试集不足以充分代表人类书写的多样性方式。因此,需要人类参与,提供对抗样本,帮助模型更好地泛化。

本指南将分为以下几个部分:

  1. 配置您的模型
  2. 与您的模型交互
  3. 标记您的模型
  4. 将所有内容组合在一起

配置您的模型

首先,您需要定义您的模型架构。我的简单模型架构如下,由两个卷积网络连接到一个50维的全连接层和一个最终的10类层组成。最后,我们使用softmax激活函数将模型的输出转化为类别的概率分布。

# 代码来源:https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

现在您已经定义了模型的结构,您需要在标准的MNIST训练/开发数据集上对其进行训练。

与您的模型交互

在这一点上,我们假设您已经训练好了模型。虽然这个模型已经训练好了,但我们的目标是使用人类参与的对抗性数据使其更加健壮。为此,您需要一种让用户与之交互的方式:具体来说,您希望用户能够在画布上书写/绘制0-9的数字,并让模型尝试对其进行分类。您可以使用🤗 Spaces来完成所有这些,它允许您快速轻松地为您的ML模型构建演示。在这里了解更多关于Spaces以及如何构建它们的信息。

下面是一个简单的空间,可以与我训练了20个周期的MNIST_Model进行交互(在测试集上达到89%的准确率)。你可以在白色画布上画一个数字,模型会根据你的图像预测出数字。完整的空间可以在这里访问。试着欺骗这个模型😁。用你最有趣的手写体写字;在画布的边缘写字;尽情发挥吧!

标记你的模型

你能成功欺骗上面的模型吗?😀 如果是的话,那么现在是时候标记你的对抗样本了。标记的步骤包括:

  1. 将对抗样本保存到数据集中
  2. 在收集到一定数量的样本后,对模型进行对抗样本的训练
  3. 重复步骤1-2多次

我编写了一个自定义的flag函数来完成所有这些操作。如果想了解更多细节,请随意查看完整的代码。

注意:Gradio有一个内置的标记回调函数,可以方便地标记你的模型的对抗样本。想了解更多信息,请阅读这里。

将所有组件整合在一起

最后一步是将所有三个组件(配置模型、与模型交互和标记模型)整合成一个演示空间!为此,我创建了用于动态收集MNIST手写识别任务的对抗数据的MNIST对抗空间。请随意在下面进行测试。

结论

动态对抗数据收集(DADC)作为一种收集多样化非饱和人类对齐数据集并改进模型评估和任务性能的方式,在机器学习社区中越来越受到关注。通过在模型的帮助下动态收集人类生成的对抗性数据,我们可以提高模型的泛化能力。

在对抗性收集的数据上欺骗和训练模型的过程应该在多个回合中重复进行1。Eric Wallace等人在他们对自然语言推理任务的实验中表明,虽然在短期内标准的非对抗性数据收集效果更好,但从长期来看,动态对抗数据收集的准确率明显更高。

使用🤗 Spaces,为你的模型动态收集对抗数据并对其进行训练变得相对容易。

Leave a Reply

Your email address will not be published. Required fields are marked *