Press "Enter" to skip to content

使用JAX入门

推动高性能数值计算和机器学习研究的未来

Lance Asper在Unsplash上的照片

介绍

JAX是由Google开发的Python库,用于在任何类型的设备上进行高性能数值计算(CPU、GPU、TPU等)。 JAX的主要应用之一是机器学习和深度学习研究开发,尽管该库主要设计为提供执行通用科学计算任务(高维矩阵运算等)所需的所有功能。

考虑到特定的高性能计算重点,JAX被设计为非常快速,构建在XLA(加速线性代数)之上。 XLA实际上是一个编译器,旨在加速线性代数运算,并可用于在其他框架(如TensorFlow和Pytorch)之后工作。此外,JAX数组的设计与Numpy遵循相同的原则,因此可以轻松迁移旧的Numpy代码到JAX,并通过GPU和TPU充分利用性能加速。

JAX的一些主要特点包括:

  • 即时(JIT)编译:JIT和加速硬件是使JAX比纯粹的Numpy更快的原因。使用jit()函数可以编译和缓存具有XLA内核的自定义函数。使用缓存将增加我们第一次运行函数时的总体执行时间,然后大大减少后续运行的时间。在使用缓存时,重要的是在需要时清除缓存,以避免过时的结果(例如全局变量的更改)。
  • 自动并行化:异步派发使得JAX向量可以进行惰性评估,只有在访问时才实例化内容(在计算完成之前将控制权返回给程序)。此外,为了实现图优化,JAX数组是不可变的(与惰性评估和图优化的类似概念适用于Apache Spark)。可以使用pmap()函数在多个GPU / TPU上并行计算。
  • 自动向量化:可以使用vmap()函数执行自动向量化以并行化操作。在向量化期间,将算法从操作单个值转换为一组值。
  • 自动微分:grad()函数可用于自动计算函数的梯度(导数)。特别是,JAX自动微分使得可以在深度学习范围之外开发通用的微分程序。可以通过递归、分支、循环进行微分,执行高阶微分(例如,雅可比矩阵和黑塞矩阵),并同时使用正向和反向模式微分。

因此,JAX能够为我们提供构建先进的深度学习模型所需的所有基础,但并不提供一些最常见的深度学习操作的开箱即用的高级工具(例如损失/激活函数、层等)。例如,ML训练期间学习的模型参数可以在JAX中存储在Pytree结构中。考虑到JAX提供的所有优势,一些以不同的DL为导向的框架已经构建在其之上,例如Haiku(由DeepMind使用)和Flax(由Google Brain使用)。

演示

作为本文的一部分,我们现在将看到如何使用JAX和Kaggle Mobile Price分类数据集[1]解决一个简单的分类问题,以预测手机的价格范围。本文中使用的所有代码(以及更多!)都可在我的GitHub和Kaggle账户上找到。

首先,我们需要确保在我们的环境中安装了JAX。

pip install jax

这时,我们可以导入必要的库和数据集(图1)。为了简化分析,我们在标签中仅使用2个类别的数据,同时缩小了特征的数量。

import pandas as pdimport jax.numpy as jnpfrom jax import gradfrom sklearn.preprocessing import StandardScalerfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import classification_reportimport matplotlib.pyplot as pltdf = pd.read_csv('/kaggle/input/mobile-price-classification/train.csv')df = df.iloc[:, 10:]df = df.loc[df['price_range'] <= 1]df.head()
图1:手机价格分类数据集(图片作者:作者)

数据集清洗后,我们现在可以将其分为训练集和测试集,并标准化输入特征,以确保它们都位于相同的范围内。此时,输入数据也转换为JAX数组。

X = df.iloc[:, :-1]y = df.iloc[:, -1]X_train, X_test, y_train, y_test = train_test_split(X, y,                                                     test_size=0.20,                                                     stratify=y)X_train, X_test, y_train, Y_test = jnp.array(X_train), jnp.array(X_test), \                                   jnp.array(y_train), jnp.array(y_test)scaler = StandardScaler()scaler.fit(X_train)X_train = scaler.transform(X_train)X_test = scaler.transform(X_test)

为了预测手机的价格范围,我们将从头开始创建一个逻辑回归模型。为此,我们首先需要创建一对辅助函数(一个用于创建Sigmoid激活函数,另一个用于二进制损失函数)。

def activation(r):    return 1 / (1 + jnp.exp(-r))def loss(c, w, X, y, lmbd=0.1):    p = activation(jnp.dot(X, w) + c)    loss = jnp.sum(y * jnp.log(p) + (1 - y) * jnp.log(1 - p)) / y.size    reg = 0.5 * lmbd * (jnp.dot(w, w) + c * c)     return - loss + reg 

现在我们可以创建训练循环并绘制结果(图2)。

n_iter, eta = 100, 1e-1w = 1.0e-5 * jnp.ones(X.shape[1])c = 1.0history = [float(loss(c, w, X_train, y_train))]for i in range(n_iter):    c_current = c    c -= eta * grad(loss, argnums=0)(c_current, w, X_train, y_train)    w -= eta * grad(loss, argnums=1)(c_current, w, X_train, y_train)    history.append(float(loss(c, w, X_train, y_train)))
图2:逻辑回归训练历史(图片作者:作者)

满意结果后,我们可以使用测试集对模型进行测试(图3)。

y_pred = jnp.array(activation(jnp.dot(X_test, w) + c))y_pred = jnp.where(y_pred > 0.5, 1, 0) print(classification_report(y_test, y_pred))
图3:测试数据的分类报告(图片作者:作者)

结论

正如本简短示例所示,JAX具有非常直观的API,紧密遵循Numpy约定,同时也可以使用相同的代码进行CPU/GPU/TPU使用。利用这些构建模块,可以创建经过设计优化性能的高度可定制的深度学习模型。

联系方式

如果您想了解我的最新文章和项目,请关注我的VoAGI并订阅我的邮件列表。以下是我的一些联系方式:

  • Linkedin
  • 个人网站
  • VoAGI个人资料
  • GitHub
  • Kaggle

参考文献

[1] “Mobile Price Classification” (ABHISHEK SHARMA). 访问地址:https://thecleverprogrammer.com/2021/03/05/mobile-price-classification-with-machine-learning/ (MIT许可证:https://github.com/alifrmf/Mobile-Price-Prediction-Classification-Analysis/tree/main )

Leave a Reply

Your email address will not be published. Required fields are marked *