Press "Enter" to skip to content

农业中的视觉变压器 | 收获创新

介绍

农业一直是人类文明的基石,为全球数十亿人提供食物和生计。随着科技的进步,我们发现了增强农业实践的新颖方法。其中一项进展是使用视觉转换器(ViTs)来对作物的叶病进行分类。在本博客中,我们将探讨视觉转换器在农业中的革命性,通过提供一种高效准确的解决方案来识别和缓解作物病害。

木薯,又称木薯或椰菜,是一种多用途的作物,可用于提供日常主食和工业应用。它的耐寒能力和抗逆性使其成为在环境条件艰苦的地区必不可少的作物。然而,木薯植株容易受到各种病害的侵袭,其中CMD和CBSD是最具破坏性的病害之一。

CMD是由白蝗传播的病毒复合体引起的,导致木薯叶片出现严重的驳斑症状。而CBSD则是由两种相关病毒引起的,主要影响储存根,使其无法食用。及早识别这些病害对于防止作物大面积损害和确保粮食安全至关重要。视觉转换器是转换器架构的进化版本,最初设计用于自然语言处理(NLP),在处理视觉数据方面表现出高度有效性。这些模型将图像作为补丁的序列进行处理,使用自注意机制来捕捉数据中的复杂模式和关系。在木薯叶病分类的背景下,ViTs通过分析感染木薯叶子的图像来训练以识别CMD和CBSD。

学习成果

  • 了解视觉转换器及其在农业中的应用,特别是叶病分类方面。
  • 了解转换器架构的基本概念,包括自注意机制,以及如何将其适应于视觉数据处理。
  • 了解视觉转换器(ViTs)在农业中的创新应用,特别是对木薯叶病早期检测的应用。
  • 深入了解视觉转换器的优势,如可扩展性和全局上下文,以及它们面临的挑战,包括计算要求和数据效率。

本文是作为“数据科学博文马拉松”的一部分发表的。

视觉转换器的崛起

近年来,由于卷积神经网络(CNNs)的发展,计算机视觉取得了巨大的进步。CNNs一直是各种与图像相关的任务的首选架构,从图像分类到目标检测。然而,视觉转换器作为一种强大的替代方案崭露头角,提供了一种新颖的处理视觉信息的方法。Google Research的研究人员在2020年发布了一篇具有开创性的论文,题为“图像价值16×16个单词:大规模图像识别的转换器”。他们将最初设计用于自然语言处理(NLP)的转换器架构应用于计算机视觉领域。这种适应为该领域带来了新的可能性和挑战。

农业中的视觉变压器 | 收获创新 四海 第1张

使用ViTs相对于传统方法具有几个优势,包括:

  • 高准确性:ViTs在准确性方面表现出色,可以可靠地检测和区分叶病。
  • 高效性:经过训练后,ViTs可以快速处理图像,适用于实时病害检测。
  • 可扩展性:ViTs可以处理不同大小的数据集,适应不同的农业环境。
  • 泛化能力:ViTs可以泛化到不同的木薯品种和病害类型,减少针对每种情况的特定模型的需求。

转换器架构简介

在深入了解视觉转换器之前,了解转换器架构的核心概念是至关重要的。转换器最初为NLP而设计,革新了语言处理任务。转换器的关键特点是自注意机制和并行化,可以更全面地理解上下文并加快训练速度。

转换器的核心是自注意机制,它使模型在进行预测时可以权衡不同输入元素的重要性。这种机制与多头注意力层结合使用,可以捕捉数据中的复杂关系。

那么,视觉转换器如何将转换器架构应用于计算机视觉领域呢?视觉转换器的基本思想是将图像视为补丁的序列,就像NLP任务将文本视为单词的序列一样。然后,转换器层通过将图像中的每个补丁嵌入向量来处理它。

Vision Transformer的关键组件

农业中的视觉变压器 | 收获创新 四海 第2张

  • 图像切片嵌入:将图像分为固定大小的非重叠切片,通常为16×16像素。然后将每个切片线性嵌入到较低维度的向量中。
  • 位置编码:将位置编码添加到切片嵌入中,以考虑切片在图像中的空间排列。这使得模型能够学习图像中切片的相对位置。
  • Transformer编码器:Vision Transformer由多个Transformer编码器层组成,类似于NLP Transformer。每个层对切片嵌入执行自注意力和前馈操作。
  • 分类头:在Transformer层的末尾添加了一个分类头,用于图像分类等任务。它接收输出嵌入并产生类别概率。

引入Vision Transformer标志着与依赖卷积层进行特征提取的CNN的显著分离。通过将图像视为切片序列,Vision Transformer在各种计算机视觉任务中,包括图像分类、目标检测,甚至视频分析中实现了最先进的结果。

实现

数据集

木薯叶病数据集包含约15,000张高分辨率木薯叶图像,展示了各种病症的不同阶段和程度。每个图像都经过细致的标记,以指示存在的病症,从而可以进行监督式机器学习和图像分类任务。木薯病具有明显的特征,因此可以将其分类为几个类别。这些类别包括木薯细菌性枯萎病(CBB)、木薯褐条病(CBSD)、木薯绿斑病(CGM)和木薯花叶病毒病(CMD)。研究人员和数据科学家利用该数据集来训练和评估机器学习模型,包括像Vision Transformer(ViTs)这样的深度神经网络。

导入必要的库

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow_addons as tfa
import glob, random, os, warnings
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns csv

加载数据集

image_size = 224
batch_size = 16
n_classes = 5

train_path = '/kaggle/input/cassava-leaf-disease-classification/train_images'
test_path = '/kaggle/input/cassava-leaf-disease-classification/test_images'

df_train = pd.read_csv('/kaggle/input/cassava-leaf-disease-classification/train.csv', dtype = 'str')

test_images = glob.glob(test_path + '/*.jpg')
df_test = pd.DataFrame(test_images, columns = ['image_path'])

classes = {0 : "木薯细菌性枯萎病 (CBB)",
           1 : "木薯褐条病 (CBSD)",
           2 : "木薯绿斑病 (CGM)",
           3 : "木薯花叶病毒病 (CMD)",
           4 : "健康"}#import csv

数据增强

def data_augment(image):
    p_spatial = tf.random.uniform([], 0, 1.0, dtype = tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype = tf.float32)
 
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k = 3) # 旋转270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k = 2) # 旋转180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k = 1) # 旋转90º
        
    return image#import csv

数据生成器

datagen = tf.keras.preprocessing.image.ImageDataGenerator(samplewise_center = True,
                                                          samplewise_std_normalization = True,
                                                          validation_split = 0.2,
                                                          preprocessing_function = data_augment)

train_gen = datagen.flow_from_dataframe(dataframe = df_train,
                                        directory = train_path,
                                        x_col = 'image_id',
                                        y_col = 'label',
                                        subset = 'training',
                                        batch_size = batch_size,
                                        seed = 1,
                                        color_mode = 'rgb',
                                        shuffle = True,
                                        class_mode = 'categorical',
                                        target_size = (image_size, image_size))

valid_gen = datagen.flow_from_dataframe(dataframe = df_train,
                                        directory = train_path,
                                        x_col = 'image_id',
                                        y_col = 'label',
                                        subset = 'validation',
                                        batch_size = batch_size,
                                        seed = 1,
                                        color_mode = 'rgb',
                                        shuffle = False,
                                        class_mode = 'categorical',
                                        target_size = (image_size, image_size))

test_gen = datagen.flow_from_dataframe(dataframe = df_test,
                                       x_col = 'image_path',
                                       y_col = None,
                                       batch_size = batch_size,
                                       seed = 1,
                                       color_mode = 'rgb',
                                       shuffle = False,
                                       class_mode = None,
                                       target_size = (image_size, image_size))#import csv

images = [train_gen[0][0][i] for i in range(16)]
fig, axes = plt.subplots(3, 5, figsize = (10, 10))

axes = axes.flatten()

for img, ax in zip(images, axes):
    ax.imshow(img.reshape(image_size, image_size, 3))
    ax.axis('off')

plt.tight_layout()
plt.show()#import csv

农业中的视觉变压器 | 收获创新 四海 第3张

模型构建

learning_rate = 0.001
weight_decay = 0.0001
num_epochs = 1

patch_size = 7  # 从输入图像中提取的补丁的大小
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # 变换器层的大小
transformer_layers = 8
mlp_head_units = [56, 28]  # 最终分类器的稠密层的大小

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = L.Dense(units, activation = tf.nn.gelu)(x)
        x = L.Dropout(dropout_rate)(x)
    return x

补丁创建

在我们的木薯叶病分类项目中,我们使用自定义层来方便地提取和编码图像补丁。这些专门的层在为Vision Transformer模型处理我们的数据时起到了重要作用。

class Patches(L.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images = images,
            sizes = [1, self.patch_size, self.patch_size, 1],
            strides = [1, self.patch_size, self.patch_size, 1],
            rates = [1, 1, 1, 1],
            padding = 'VALID',
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
        
plt.figure(figsize=(4, 4))

x = train_gen.next()
image = x[0][0]

plt.imshow(image.astype('uint8'))
plt.axis('off')

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size = (image_size, image_size)
)

patches = Patches(patch_size)(resized_image)
print(f'图像大小: {image_size} X {image_size}')
print(f'补丁大小: {patch_size} X {patch_size}')
print(f'每个图像的补丁数: {patches.shape[1]}')
print(f'每个补丁的元素数: {patches.shape[-1]}')

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))

for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype('uint8'))
    plt.axis('off')
    
class PatchEncoder(L.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = L.Dense(units = projection_dim)
        self.position_embedding = L.Embedding(
            input_dim = num_patches, output_dim = projection_dim
        )

    def call(self, patch):
        positions = tf.range(start = 0, limit = self.num_patches, delta = 1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded#import csv

补丁层 (class Patches(L.Layer)

补丁层通过从原始输入图像中提取补丁来启动我们的数据预处理流水线。这些补丁代表原始图像的较小、非重叠区域。该层对图像批次进行操作,提取特定大小的补丁并对其进行重新整形以进行进一步处理。这一步骤对于使模型关注图像中的细节非常重要,有助于捕捉复杂的模式。

图像补丁的可视化

在提取补丁之后,我们通过显示一张样本图像并叠加显示提取的补丁的网格来展示它们对图像的影响。这种可视化提供了关于图像如何被划分为这些补丁的见解,突出了补丁的大小以及从每个图像中提取的补丁数量。它有助于理解预处理阶段,并为后续的分析奠定基础。

补丁编码层 (class PatchEncoder(L.Layer)

一旦提取出补丁,它们将通过PatchEncoder层进一步处理。该层在编码每个补丁中包含的信息方面起到关键作用。它由两个关键组件组成:一个线性投影,增强补丁的特征,以及一个位置嵌入,增加空间上下文。最终得到的丰富的补丁表示对Vision Transformer的分析和学习至关重要,最终有助于模型在准确疾病分类方面的有效性。

自定义层Patches和PatchEncoder对于我们的木薯叶病分类数据预处理流程至关重要。它们使模型能够专注于图像补丁,增强其辨别相关模式和特征以实现精确疾病分类的能力。这个过程显著增强了我们的Vision Transformer模型的整体性能。

def vision_transformer():
    inputs = L.Input(shape =(image_size,image_size,3))
    
    #创建补丁。
    patches = Patches(patch_size)(inputs)
    
    #编码补丁。
    encoded_patches = PatchEncoder(num_patches,projection_dim)(patches)

    #创建多层Transformer块。
    for _ in range(transformer_layers):
        
        #层规范化1。
        x1 = L.LayerNormalization(epsilon = 1e-6)(encoded_patches)
        
        #创建多头注意力层。
        attention_output = L.MultiHeadAttention(
            num_heads = num_heads,key_dim = projection_dim,dropout = 0.1
        )(x1,x1)
        
        #跳过连接1。
        x2 = L.Add()([attention_output,encoded_patches])
        
        #层规范化2。
        x3 = L.LayerNormalization(epsilon = 1e-6)(x2)
        
        #MLP。
        x3 = mlp(x3,hidden_units = transformer_units,dropout_rate = 0.1)
        
        #跳过连接2。
        encoded_patches = L.Add()([x3,x2])

    #创建一个[batch_size,projection_dim]张量。
    representation = L.LayerNormalization(epsilon = 1e-6)(encoded_patches)
    representation = L.Flatten()(representation)
    representation = L.Dropout(0.5)(representation)
    
    #添加MLP。
    features = mlp(representation,hidden_units = mlp_head_units,dropout_rate = 0.5)
    
    #分类输出。
    logits = L.Dense(n_classes)(features)
    
    #创建模型。
    model = tf.keras.Model(inputs = inputs,outputs = logits)
    
    return model
    
decay_steps = train_gen.n // train_gen.batch_size
initial_learning_rate = learning_rate

lr_decayed_fn = tf.keras.experimental.CosineDecay(initial_learning_rate,decay_steps)

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_decayed_fn)

optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)

model = vision_transformer()
    
model.compile(optimizer = optimizer, 
              loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing = 0.1), 
              metrics = ['accuracy'])

STEP_SIZE_TRAIN = train_gen.n // train_gen.batch_size
STEP_SIZE_VALID = valid_gen.n // valid_gen.batch_size

earlystopping = tf.keras.callbacks.EarlyStopping(monitor = 'val_accuracy',
                                                 min_delta = 1e-4,
                                                 patience = 5,
                                                 mode = 'max',
                                                 restore_best_weights = True,
                                                 verbose = 1)

checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath = './model.hdf5',
                                                  monitor = 'val_accuracy', 
                                                  verbose = 1, 
                                                  save_best_only = True,
                                                  save_weights_only = True,
                                                  mode = 'max')

callbacks = [earlystopping, lr_scheduler, checkpointer]

model.fit(x = train_gen,
          steps_per_epoch = STEP_SIZE_TRAIN,
          validation_data = valid_gen,
          validation_steps = STEP_SIZE_VALID,
          epochs = num_epochs,
          callbacks = callbacks)
#import csv

代码解释

这段代码定义了一个专门用于木薯病分类任务的自定义Vision Transformer模型。它包含多个Transformer块,每个块由多头注意力层、跳过连接和多层感知器(MLP)组成。结果是一个强大的模型,能够捕捉木薯叶图像中的复杂模式。

首先,vision_transformer()函数成为主角,定义了我们的Vision Transformer的架构蓝图。该函数概述了模型如何处理和学习木薯叶图像,使其能够精确分类疾病。

为了进一步优化训练过程,我们实现了一个学习率调度器。该调度器采用余弦衰减策略,根据模型的学习情况动态调整学习率。这种动态适应性增强了模型的收敛性,使其能够高效地达到最佳性能。

一旦我们确定了模型的架构和训练策略,我们就开始进行模型的编译。在这个阶段,我们会指定一些关键的组件,比如损失函数、优化器和评估指标。这些元素会经过精心选择,以确保我们的模型能够优化其学习过程,做出准确的预测。

最后,我们通过应用训练回调来确保模型的训练效果。其中,早停和模型检查点是两个关键的回调。早停会监视模型在验证数据上的表现,并在改进停滞时进行干预,从而防止过拟合。同时,模型检查点会记录我们模型的最佳表现版本,以便将来使用。

这些组件共同构建了一个全面的框架,用于开发、训练和优化我们的视觉转换模型,这是实现准确的木薯叶病分类的关键步骤。

视觉转换模型在农业中的应用

视觉转换模型在木薯种植中的应用不仅仅是研究和新颖性的延伸,它还为当前的挑战提供了实际的解决方案:

  • 早期病害检测:视觉转换模型能够早期检测到木薯花叶病和木薯褐斑病,使农民能够及时采取行动,防止病害的传播,减少作物损失。
  • 资源效率:通过视觉转换模型,可以更有效地利用时间和劳动力资源,自动化的病害检测减少了对每棵木薯植株的手工检查的需求。
  • 精准农业:将视觉转换模型与其他技术(如无人机和物联网设备)结合,实现精准农业,可以精确定位和治疗病害热点。
  • 提高粮食安全:通过减轻病害对木薯产量的影响,视觉转换模型有助于提高木薯作为主要粮食作物的地区的粮食安全。

视觉转换模型的优势

视觉转换模型相对传统的基于卷积神经网络(CNN)的方法具有几个优势:

  • 可扩展性:视觉转换模型可以处理不同分辨率的图像,而无需对模型架构进行修改。在实际应用中,图像的大小会有所不同,这种可扩展性尤为重要。
  • 全局上下文:视觉转换模型中的自注意机制能够有效捕捉全局上下文。这对于识别杂乱场景中的对象非常重要。
  • 组件更少:与CNN不同,视觉转换模型不需要复杂的组件,如池化层和卷积滤波器。这简化了模型的设计和维护。
  • 迁移学习:视觉转换模型可以在大型数据集上进行预训练,使其成为迁移学习的理想选择。预训练模型可以通过相对较少的任务特定数据进行微调,用于特定任务。

挑战和未来方向

虽然视觉转换模型取得了显著的进展,但它们也面临着几个挑战:

  • 计算资源:训练大型视觉转换模型需要大量的计算资源,这对于小型研究团队和组织来说可能是一个障碍。
  • 数据效率:视觉转换模型对数据的需求较高,在有限的数据条件下实现稳健的性能可能是具有挑战性的。开发更加数据效率的训练技术是一个紧迫的问题。
  • 可解释性:转换模型常常被批评为黑盒子。研究人员正在努力改善视觉转换模型的可解释性,特别是在安全关键应用中。
  • 实时推理:使用大型视觉转换模型进行实时推理可能需要大量的计算资源。优化以实现更快推理速度是一个活跃的研究领域。

结论

视觉转换模型通过为叶病分类提供准确、高效的解决方案,改变了木薯种植业。它们处理视觉数据的能力,结合数据采集和模型训练的进展,对保护木薯作物和确保粮食安全具有巨大潜力。虽然仍然存在挑战,但持续的研究和实际应用推动了视觉转换模型在木薯种植业中的应用。持续的创新和合作将使视觉转换模型成为全球木薯农民的宝贵工具,为可持续农业实践做出贡献,减少由毁灭性叶病引起的作物损失。

要点总结

  • 视觉转换模型(ViTs)将转换器架构应用于计算机视觉,将图像处理为补丁序列。
  • ViTs最初是为计算机视觉设计的,现在也被应用于农业,以解决早期病害检测等挑战。
  • 解决计算资源和数据效率等挑战,使ViTs成为计算机视觉未来的有希望的技术。

常见问题

本文章中显示的媒体不归属于Analytics Vidhya,仅供作者自行决定使用。

Leave a Reply

Your email address will not be published. Required fields are marked *