Press "Enter" to skip to content

使用Transformers.js制作基于机器学习的网页游戏

在这篇博客文章中,我将向您展示我如何制作一个名为Doodle Dash的实时ML驱动的网页游戏,该游戏完全在您的浏览器中运行(多亏了Transformers.js)。本教程的目标是向您展示制作自己的ML驱动网页游戏有多容易…正好赶上即将到来的开源AI游戏马拉松(2023年7月9日)。如果您还没有参加游戏马拉松,请加入!

视频:Doodle Dash演示视频

  • 演示: Doodle Dash
  • 源代码: doodle-dash
  • 参加游戏马拉松: 开源AI游戏马拉松

概述

在我们开始之前,让我们先来谈谈我们将要创建的内容。这个游戏的灵感来自于Google的Quick, Draw!游戏,您会得到一个词语,然后神经网络有20秒的时间猜测您正在画的是什么(重复6次)。实际上,我们将使用他们的训练数据来训练我们自己的草图检测模型!您是不是喜欢开源? 😍

在我们的版本中,您将有一分钟的时间一次性画出尽可能多的物品。如果模型预测出了正确的标签,画布将被清除,然后您将得到一个新的单词。一直做下去,直到计时器结束!由于游戏在浏览器中本地运行,我们完全不必担心服务器延迟。模型能够随着您的绘画实时进行预测,每秒超过60次的预测… 🤯 太厉害了!

本教程分为3个部分:

  1. 训练神经网络
  2. 使用Transformers.js在浏览器中运行
  3. 游戏设计

1. 训练神经网络

训练数据

我们将使用Google的Quick, Draw!数据集的一个子集来训练我们的模型,该数据集包含345个类别的500万多个绘画样本。以下是数据集中的一些样本:

使用Transformers.js制作基于机器学习的网页游戏 四海 第1张

模型架构

我们将对apple/mobilevit-small进行微调,这是一个轻量级且适用于移动设备的Vision Transformer,它已经在ImageNet-1k上进行了预训练。它只有5.6M的参数(约20MB的文件大小),非常适合在浏览器中运行!有关更多信息,请查看MobileViT的论文和下面的模型架构。

使用Transformers.js制作基于机器学习的网页游戏 四海 第2张

微调

使用Transformers.js制作基于机器学习的网页游戏 四海 第3张

为了让博客文章(相对)简短,我们准备了一个Colab笔记本,展示了我们在数据集上微调apple/mobilevit-small的具体步骤。从高层次上讲,这包括以下步骤:

  1. 加载”Quick, Draw!”数据集。

  2. 使用MobileViTImageProcessor转换数据集。

  3. 定义我们的整理函数和评估指标。

  4. 使用MobileViTForImageClassification.from_pretrained加载预训练的MobileVIT模型。

  5. 使用TrainerTrainingArguments辅助类训练模型。

  6. 使用🤗 Evaluate评估模型。

注意:您可以在Hugging Face Hub上找到我们微调的模型。

2. 使用Transformers.js在浏览器中运行

什么是Transformers.js?

Transformers.js是一个JavaScript库,允许您直接在浏览器中运行🤗 Transformers(无需服务器)!它的设计目标是与Python库在功能上相当,这意味着您可以使用非常类似的API运行相同的预训练模型。

在幕后,Transformers.js使用ONNX Runtime,因此我们需要将我们微调的PyTorch模型转换为ONNX。

将我们的模型转换为ONNX

幸运的是,🤗 Optimum库使将您的微调模型转换为ONNX变得非常简单!最简单(也是推荐的方法)是:

  1. 克隆Transformers.js存储库并安装必要的依赖项:

    git clone https://github.com/xenova/transformers.js.git
    cd transformers.js
    pip install -r scripts/requirements.txt
  2. 运行转换脚本(它在内部使用Optimum):

    python -m scripts.convert --model_id <model_id>

    其中<model_id>是您要转换的模型的名称(例如Xenova/quickdraw-mobilevit-small)。

设置我们的项目

让我们首先使用Vite搭建一个简单的React应用:

npm create vite@latest doodle-dash -- --template react

接下来,进入项目目录并安装必要的依赖项:

cd doodle-dash
npm install
npm install @xenova/transformers

然后,可以通过运行以下命令启动开发服务器:

npm run dev

在浏览器中运行模型

运行机器学习模型需要大量计算,因此在单独的线程中执行推理非常重要。这样我们就不会阻塞主线程,主线程用于渲染UI和响应您的绘图手势😉。Web Workers API使这变得非常简单!

src目录中创建一个新文件(例如worker.js),并添加以下代码:

import { pipeline, RawImage } from "@xenova/transformers";

const classifier = await pipeline("image-classification", 'Xenova/quickdraw-mobilevit-small', { quantized: false });

const image = await RawImage.read('https://hf.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png');

const output = await classifier(image.grayscale());
console.log(output);

现在,我们可以在App.jsx文件中使用此worker,通过向App组件添加以下代码:

import { useState, useEffect, useRef } from 'react'
// ... 其他导入

function App() {
    // 创建对worker对象的引用。
    const worker = useRef(null);

    // 我们使用`useEffect`勾子来设置worker,一旦`App`组件被挂载。
    useEffect(() => {
        if (!worker.current) {
            // 如果worker不存在,则创建worker。
            worker.current = new Worker(new URL('./worker.js', import.meta.url), {
                type: 'module'
            });
        }

        // 创建一个用于接收来自worker线程的消息的回调函数。
        const onMessageReceived = (e) => { /* 查看代码 */ };

        // 将回调函数附加为事件监听器。
        worker.current.addEventListener('message', onMessageReceived);

        // 定义组件卸载时的清理函数。
        return () => worker.current.removeEventListener('message', onMessageReceived);
    });

    // ... 其他组件内容
}

您可以通过运行开发服务器(使用npm run dev),访问本地网站(通常为http://localhost:5173/),并打开浏览器控制台来测试一切是否正常。您应该在控制台中看到模型的输出被记录。

[{ label: "skateboard", score: 0.9980043172836304 }]

太棒了!🥳 虽然上面的代码只是最终产品的一小部分,但它展示了机器学习方面的简单性!其余部分只是使其看起来漂亮并添加一些游戏逻辑。

3. 游戏设计

在本节中,我将简要讨论游戏设计过程。作为提醒,您可以在GitHub上找到项目的完整源代码,所以我不会详细介绍代码本身。

充分利用实时性能

在浏览器中进行推理的主要优点之一是我们可以实时进行预测(每秒超过60次)。在原始的Quick, Draw!游戏中,模型每隔几秒钟才进行一次新的预测。我们可以在我们的游戏中做同样的事情,但这样我们就无法充分利用其实时性能!所以,我决定重新设计主游戏循环:

  • 我们的版本任务是在60秒内正确绘制尽可能多的涂鸦(每次一个提示),而不是六个持续20秒的回合(其中每个回合对应一个新词)。
  • 如果你遇到一个你无法绘制的词,你可以跳过它(但这将消耗你剩余时间的3秒)。
  • 在原始游戏中,由于模型每隔几秒钟会猜测一次,它可以逐渐从列表中划掉标签,直到最终猜对为止。在我们的版本中,我们会减少模型对前n个错误标签的得分,其中n会随着用户继续绘制而逐渐增加。

提高生活质量

原始数据集包含345个不同的类别,由于我们的模型相对较小(约20MB),有时无法正确猜测某些类别。为了解决这个问题,我们删除了一些词语,这些词语要么:

  • 与其他标签太相似(例如,“谷仓”与“房子”)
  • 太难理解(例如,“动物迁移”)
  • 绘制时难以细节充足(例如,“大脑”)
  • 含糊不清(例如,“蝙蝠”)

在过滤后,我们仍然剩下300多个不同的类别!

奖励:想出游戏名称

秉承开源开发的精神,我决定向Hugging Chat询问一些游戏名称的想法…不用说,它没有让我失望!

使用Transformers.js制作基于机器学习的网页游戏 四海 第4张

我喜欢“Doodle Dash”(建议#4)的谐音,所以我决定采用它。谢谢Hugging Chat!🤗


我希望您喜欢与我一起构建这个游戏!如果您有任何问题或建议,您可以在Twitter、GitHub或🤗 Hub上找到我。此外,如果您想改进游戏(游戏模式?道具?动画?音效?),请随意Fork项目并提交Pull Request!我很想看看您的成果!

PS:别忘了参加开源AI游戏大赛!希望这篇博文能激发您使用Transformers.js构建自己的网络游戏的灵感!😉我们在游戏大赛见!🚀

Leave a Reply

Your email address will not be published. Required fields are marked *