01 Vision Transformer

使用 TensorFlow 构建 ViT B-16 模型。

1. 引言

在计算机视觉任务中通常使用注意力机制对特征进行增强或者使用注意力机制替换某些卷积层的方式来实现对网络结构的优化,这些方法都在原有卷积网络的结构中运用注意力机制进行特征增强

而 ViT 依赖于原有的编码器结构进行搭建,并将其用于图像分类任务,在减少模型参数量的同时提高了检测准确度。

将 Transformer 用于图像分类任务主要有以下 5 个过程:(1)将输入图像或特征进行序列化;(2)添加位置编码;(3)添加可学习的嵌入向量;(4)输入到编码器中进行编码;(5)将输出的可学习嵌入向量用于分类。结构图如下:


2. Patch Embedding

以 b×224×224×3 的输入图片为例。首先进行图像分块,将原图片切分为 14×14 个图像块(Patch),每个 Patch 的大小为 16×16,通过提取输入图片中的平坦像素向量,将每个输入 Patch 送入线性投影层,得到 Patch Embeddings。
在代码中,先经过一个 kernel=(16,16),strides=16 的卷积层划分图像块,再将 h和w 维度整合为 num_patches 维度,代表一共有 196 个 patch,每个 patch 为 16×16。

3. 添加类别标签和位置编码

为了输出融合了全局语义信息的向量表示,在第一个输入向量前添加可学习分类变量。经过编码器编码后,在最后一层输出中,该位置对应的输出向量就可以用于分类任务。与其他位置对应的输出向量相比,该向量可以更好的融合图像中各个图像块之间的依赖关系。
在 Transformer 更新的过程中,输入序列的顺序信息会丢失。Transformer 本身并没有办法学习这个信息,所以需要一种方法将位置表示聚合到模型的输入嵌入中。我们对每个 Patch 进行位置编码,该位置编码采用随机初始化,之后参与模型训练。与传统三角函数的位置编码方法不同,该方法是可学习的。
最后,将 Patch-Embeddings 和 class-token 进行堆叠,和 Position-Embeddings 进行叠加,得到最终嵌入向量,该向量输入给 Transformer 层进行后续处理。


4. 多头自注意力模块

Transformer 层中,主要包含多头注意力机制和多层感知机模块,下面先介绍多头自注意力模块。
单个的注意力机制,其每个输入包含三个不同的向量,分别为 Query向量(Q),Key向量(K),Value向量(V)。他们的结果分别由输入特征图和三个权重做矩阵乘法得到。



接着为每一个输入计算一个得分Score=q*k。
为了使梯度稳定,对 Score 的值进行归一化处理,并将结果通过 softmax 函数进行映射。之后再和 v 做矩阵相乘,得到加权后每个输入向量的得分 v。计算完后再乘以一个权重张量 W 提取特征。
计算公式如下,其中 \sqrt{d_{k}代表 K 向量维度的平方根。


5. MLP 多层感知器

这个部分简单来看就是两个全连接层提取特征,流程图如下。第一个全连接层通道上升4倍,第二个全连接层通道下降为原来。


6. 特征提取模块

Transformer 的单个特征提取模块是由 多头注意力机制 和 多层感知机模块 组合而成,encoder_block 模块的流程图如下。
输入图像像经过 LayerNormalization 标准化后,再经过我们上面定义的多头注意力模块,将输出结果和输入特征图残差连接,图像在特征提取过程中shape保持不变。

将输出结果再经过标准化,然后送入多层感知器提取特征,再使用残差连接输入和输出。



而 transformer 的特征提取模块是由多个 encoder_block 叠加而成,这里连续使用12个 encoder_block 模块来提取特征。

7. 主干网络

接下来就搭建网络了,将上面所有的模块组合到一起,如下图所示。

在下面代码中要注意的是 cls_ticks = x[:,0] 取出所有的类别标签。 因为在 cls_pos_embed 模块中,我们将 cls_token 和输入图像在 patch 维度上堆叠 layers.concate,用于学习每张特征图的类别信息,取出的类别标签 cls_ticks 的 shape 为 [b, 768]。最后经过一个全连接层得出每张图片属于每个类别的得分。


8. 查看模型结构

这里有个注意点,keras.Input() 的参数问题,创建输入层时,参数 shape 不需要指定batch维度,batch_shape 需要指定batch维度。

keras.Input(shape=None, batch_shape=None, name=None, dtype=K.floatx(), sparse=False, tensor=None)
'''
shape: 形状元组(整型),不包括batch size。for instance, shape=(32,) 表示了预期的输入将是一批32维的向量。
batch_shape: 形状元组(整型),包括了batch size。for instance, batch_shape=(10,32)表示了预期的输入将是10个32维向量的批次。
'''

接收模型后,通过 model.summary() 查看模型结构和参数量,通过 get_flops() 参看浮点计算量。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# --------------------------------------------- #
# (1)Embedding 层
# inputs代表输入图像,shape为224*224*3
# out_channel代表该模块的输出通道数,即第一个卷积层输出通道数=768
# patch_size代表卷积核在图像上每16*16个区域卷积得出一个值
# --------------------------------------------- #
def patch_embed(inputs, out_channel, patch_size=16):
 
    # 获得输入图像的shape=[b,224,224,3]
    b, h, w, c = inputs.shape
 
    # 获得划分后每张图像的size=(14,14)
    grid_h, grid_w = h//patch_size, w//patch_size
 
    # 计算图像宽高共有多少个像素点 n = h*w
    num_patches = grid_h * grid_w
 
    # 卷积 [b,224,224,3]==>[b,14,14,768]
    x = layers.Conv2D(filters=out_channel, kernel_size=(patch_size,patch_size), strides=patch_size, padding='same')(inputs)
 
    # 维度调整 [b,h,w,c]==>[b,n,c]
    # [b,14,14,768]==>[b,196,768]
    x = tf.reshape(x, shape=[b, num_patches, out_channel])
 
    return x


# --------------------------------------------- #
# (2)类别标签和位置编码
# --------------------------------------------- #
def class_pos_add(inputs):
 
    # 获得输入特征图的shape=[b,196,768]
    b, num_patches, channel = inputs.shape
 
    # 类别信息 [1,1,768]
    # 直接通过classtoken来判断类别,classtoken能够学到其他token中的分类相关的信息
    cls_token = layers.Layer().add_weight(name='classtoken', shape=[1,1,channel], dtype=tf.float32,
                                          initializer=keras.initializers.Zeros(), trainable=True)  
 
    # 可学习的位置变量 [1,197,768], 初始化为0,trainable=True代表可以通过反向传播更新权重
    pos_embed = layers.Layer().add_weight(name='posembed', shape=[1,num_patches+1,channel], dtype=tf.float32,
                                          initializer=keras.initializers.RandomNormal(stddev=0.02), trainable=True)
 
    # 将类别信息在维度上广播 [1,1,768]==>[b,1,768]
    cls_token = tf.broadcast_to(cls_token, shape=[b, 1, channel])
 
    # 在num_patches维度上堆叠,注意要把cls_token放前面
    # [b,1,768]+[b,196,768]==>[b,197,768]
    x = layers.concatenate([cls_token, inputs], axis=1)
 
    # 将位置信息叠加上去
    x = tf.add(x, pos_embed)
 
    return x  # [b,197,768]


# --------------------------------------------- #
# (3)多头自注意力模块
# inputs: 代表编码后的特征图
# num_heads: 代表多头注意力中heads个数
# qkv_bias: 计算qkv是否使用偏置
# atten_drop_rate, proj_drop_rate:代表两个全连接层后面的dropout层
# --------------------------------------------- #
def attention(inputs, num_heads, qkv_bias=False, atten_drop_rate=0., proj_drop_rate=0.):
 
    # 获取输入特征图的shape=[b,197,768]
    b, num_patches, channel = inputs.shape
    # 计算每个head的通道数
    head_channel = channel // num_heads
    # 公式的分母,根号d
    scale = head_channel ** 0.5
 
    # 经过一个全连接层计算qkv [b,197,768]==>[b,197,768*3]
    qkv = layers.Dense(channel*3, use_bias=qkv_bias)(inputs)
    # 调整维度 [b,197,768*3]==>[b,197,3,num_heads,c//num_heads]
    qkv = tf.reshape(qkv, shape=[b, num_patches, 3, num_heads, channel//num_heads])
    # 维度重排 [b,197,3,num_heads,c//num_heads]==>[3,b,num_heads,197,c//num_heads]
    qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])
    # 获取q、k、v的值==>[b,num_heads,197,c//num_heads]
    q, k, v = qkv[0], qkv[1], qkv[2]
 
    # 矩阵乘法, q 乘 k 的转置,除以缩放因子。矩阵相乘计算最后两个维度
    # [b,num_heads,197,c//num_heads] * [b,num_heads,c//num_heads,197] ==> [b,num_heads,197,197]
    atten = tf.matmul(a=q, b=k, transpose_b=True) / scale
    # 对每张特征图进行softmax函数
    atten = tf.nn.softmax(atten, axis=-1)
    # 经过dropout层
    atten = layers.Dropout(rate=atten_drop_rate)(atten)
    # 再进行矩阵相乘==>[b,num_heads,197,c//num_heads]
    atten = tf.matmul(a=atten, b=v)
 
    # 维度重排==>[b,197,num_heads,c//num_heads]
    x = tf.transpose(atten, perm=[0, 2, 1, 3])
    # 维度调整==>[b,197,c]==[b,197,768]
    x = tf.reshape(x, shape=[b, num_patches, channel])
 
    # 调整之后再经过一个全连接层提取特征==>[b,197,768]
    x = layers.Dense(channel)(x)
    # 经过dropout
    x = layers.Dropout(rate=proj_drop_rate)(x)
 
    return x


# ------------------------------------------------------ #
# (4)MLP block
# inputs代表输入特征图;mlp_ratio代表第一个全连接层上升通道倍数;
# drop_rate代表杀死神经元概率
# ------------------------------------------------------ #
def mlp_block(inputs, mlp_ratio=4.0, drop_rate=0.):
 
    # 获取输入图像的shape=[b,197,768]
    b, num_patches, channel = inputs.shape
 
    # 第一个全连接上升通道数==>[b,197,768*4]
    x = layers.Dense(int(channel*mlp_ratio))(inputs)
    # GeLU激活函数
    x = layers.Activation('gelu')(x)
    # dropout层
    x = layers.Dropout(rate=drop_rate)(x)
 
    # 第二个全连接层恢复通道数==>[b,197,768]
    x = layers.Dense(channel)(x)
    # dropout层
    x = layers.Dropout(rate=drop_rate)(x)
 
    return x


# ------------------------------------------------------ #
# (5)单个特征提取模块
# num_heads:代表自注意力的heads个数
# epsilon:小浮点数添加到方差中以避免除以零
# drop_rate:自注意力模块之后的dropout概率
# ------------------------------------------------------ #
def encoder_block(inputs, num_heads, epsilon=1e-6, atten_drop_rate=0., proj_drop_rate=0., drop_rate=0.):
 
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(inputs)
    # 自注意力模块
    x = attention(x, num_heads=num_heads, atten_drop_rate=atten_drop_rate, proj_drop_rate=proj_drop_rate)
    # 残差连接输入和输出
    # x1 = x + inputs
    x1 = layers.add([x, inputs])
    
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(x1)
    # MLP模块
    x = mlp_block(x, drop_rate=drop_rate)
    # 残差连接
    # x2 = x + x1
    x2 = layers.add([x, x1])
 
    return x2  # [b,197,768]
 
# ------------------------------------------------------ #
# (6)连续12个特征提取模块
# ------------------------------------------------------ #
def transformer_block(x, num_heads):
 
    # 重复堆叠12次
    for _ in range(12):
        # 本次的特征提取块的输出是下一次的输入
        x = encoder_block(x, num_heads=num_heads)
 
    return x  # 返回特征提取12次后的特征图


# ---------------------------------------------------------- # 
# (7)主干网络
# batch_shape:代表输入图像的shape=[8,224,224,3]
# classes:代表最终的分类数
# drop_rate:代表位置编码后的dropout层的drop率
# num_heads:代表自注意力机制的heads个数
# epsilon:小浮点数添加到方差中以避免除以零
# ---------------------------------------------------------- # 
def VIT(batch_shape, classes, drop_rate=0., num_heads=12, epsilon=1e-6):
 
    # 构造输入层 [b,224,224,3]
    inputs = keras.Input(batch_shape=batch_shape)
 
    # PatchEmbedding层==>[b,196,768]
    x = patch_embed(inputs, out_channel=768)
 
    # 类别和位置编码==>[b,197,768]
    x = class_pos_add(x)
 
    # dropout层
    x = layers.Dropout(rate=drop_rate)(x)
 
    # 经过12次特征提取==>[b,197,768]
    x = transformer_block(x, num_heads=num_heads)
 
    # LayerNormalization
    x = layers.LayerNormalization(epsilon=epsilon)(x)
 
    # 取出特征图的类别标签,在第(2)步中我们把类别标签放在了最前面
    cls_ticks = x[:,0]
    # 全连接层分类
    outputs = layers.Dense(classes)(cls_ticks)
 
    # 构建模型
    model = keras.Model(inputs, outputs)
 
    return model


# ---------------------------------------------------------- # 
# (8)接收模型
# ---------------------------------------------------------- # 
if __name__ == '__main__':
 
    batch_shape = [8,224,224,3]  # 输入图像的尺寸
    classes = 1000  # 分类数
 
    # 接收模型
    model = VIT(batch_shape, classes)
 
    # 查看模型结构
    model.summary()
    
    # 查看浮点计算量 flops = 51955425272
    from keras_flops import get_flops
    print('flops:', get_flops(model, batch_size=8))

上述代码在tensorflow2.14.0中能运行,但在2.16.0中出现问题,结合官网给出程序修改,如下

# 1 Setup
# 环境准备
# 导入 TensorFlow、NumPy等相关库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import matplotlib.pyplot as plt

# 2 Prepare the data
# 数据导入和查看
num_classes = 10  # 分类数目为10个
input_shape = (32, 32, 3)  # 输入的图片形状为32x32像素,3个通道(RGB)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# 加载 CIFAR10 数据集,并将其拆分为训练和测试集,x_train和x_test为图像数据,y_train和y_test为标签数据

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
# 打印训练集的图像数据形状和标签数据形状
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
# 打印测试集的图像数据形状和标签数据形状

#3 Configure the hyperparameters
# 配置超参数
learning_rate = 0.001  # 学习率,控制模型更新的步长大小
weight_decay = 0.0001  # 权重衰减,控制模型复杂度,防止过拟合
batch_size = 256  # 每次训练使用的样本数量
num_epochs = 10  # 训练的总轮数,真正训练,可设置为100
image_size = 72  # 将输入图像的大小调整为此大小
patch_size = 6  # 从输入图像中提取的图像块patch的大小
num_patches = (image_size // patch_size)**2  # 输入图像中的图像块数
projection_dim = 64  # Transformer模型中的投影维度,用于计算每个图像块的嵌入向量
num_heads = 4  # Transformer模型中的注意力头数,用于计算每个图像块的特征向量
transformer_units = [  # Transformer层的大小,每一层都有两个子层,一个是多头自注意力子层,一个是全连接子层
    projection_dim * 2,  # 第一子层的大小是投影维度的两倍
    projection_dim,  # 第二子层的大小是投影维度
]
transformer_layers = 8  # Transformer模型中Transformer层的数量
mlp_head_units = [2048, 1024]  # 最终分类器中的两个全连接层的大小

# 4 Use data augmentation
# 数据增强
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),  # 归一化图像数据
        layers.Resizing(image_size, image_size),  # 调整图像大小为指定大小
        layers.RandomFlip("horizontal"),  # 水平随机翻转图像
        layers.RandomRotation(factor=0.02),  # 随机旋转图像
        layers.RandomZoom(
            height_factor=0.2,
            width_factor=0.2  # 随机缩放图像
        ),
    ],
    name="data_augmentation",  # 给数据增强模型起个名字
)

data_augmentation.layers[0].adapt(x_train)  # 对训练集进行数据归一化,计算均值和方差用于后续的归一化处理


# 5 Implement multilayer perceptron (MLP)
# MLP 实现
def mlp(x, hidden_units, dropout_rate):
    # 定义一个MLP函数,其中参数x表示输入数据,hidden_units表示每一层MLP的神经元数,dropout_rate表示Dropout比率
    for units in hidden_units:
        # 循环遍历所有的隐藏层神经元数
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        # 全连接层,其中units为该层神经元个数,激活函数为gelu
        x = layers.Dropout(dropout_rate)(x)  # Dropout层,使部分神经元随机失活,防止过拟合
    return x  # 返回处理后的数据x


# 6 Implement patch creation as a layer
# 将patch创建实现为层
class Patches(layers.Layer):

    def __init__(self, patch_size):
        super().__init__()  # 继承父类的初始化方法
        self.patch_size = patch_size  # 图像块的大小

    def call(self, images):
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = tf.image.extract_patches(  # 使用 TensorFlow 的图像处理 API 获取图像块
            images=images,  # 输入的图像
            sizes=[1, self.patch_size, self.patch_size, 1],  # 图像块的大小
            strides=[1, self.patch_size, self.patch_size, 1],  # 滑动步长
            rates=[1, 1, 1, 1],  # 对输入数据进行扩展的因素
            padding="VALID",  # 填充方式
        )
        patches = tf.reshape(patches, [
            batch_size, num_patches_h * num_patches_w,
            self.patch_size * self.patch_size * channels
        ])  # 对图像块进行形状变换
        return patches  # 返回处理后的图像块

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


# 7 Implement the patch encoding layer
# 实现 patch 的编码
# 添加类别编码,并将可学习的位置嵌入到投影向量中
class PatchEncoder(layers.Layer):
    # 定义一个类,继承自Layer类
    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        self.num_patches = input_shape[-2]
        self.projection_dim = input_shape[-1]
        # 类别信息 [1,1,768]
        #直接通过classtoken来判断类别,classtoken能够学到其他token中的分类相关的信息
        self.cls_token = self.add_weight(
            shape=(1, 1, self.projection_dim),
            dtype=tf.float32,
            initializer=keras.initializers.Zeros(),
            trainable=True,
            name='cls',
        )

        self.pe = self.add_weight(
            shape=[1, self.num_patches + 1, self.projection_dim],
            dtype=tf.float32,
            initializer=keras.initializers.RandomNormal(stddev=0.02),
            trainable=True,
            name='pos_embedding',
        )
        super(PatchEncoder, self).build(input_shape)

    def call(self, patch):
        batch_size = tf.shape(patch)[0]  # 获取图片的批次大小
        # 将类别信息在维度上广播 [1,1,768]==>[b,1,768]
        cls_broadcasted = tf.cast(tf.broadcast_to(
            self.cls_token, shape=[batch_size, 1, self.projection_dim]),
                                  dtype=patch.dtype)
        # 定义call方法,用于前向传播
        x = tf.concat([cls_broadcasted, patch], 1)
        # 在num_patches维度上堆叠,注意要把cls_token放前面
        # [b,1,768]+[b,196,768]==>[b,197,768]
        encoded = x + tf.cast(self.pe, dtype=patch.dtype)
        # 再加上嵌入的位置信息
        return encoded
        # 返回编码结果


# --------------------------------------------- #
# 自行编写:多头自注意力模块
# inputs: 代表编码后的特征图
# num_heads: 代表多头注意力中heads个数
# qkv_bias: 计算qkv是否使用偏置
# atten_drop_rate, proj_drop_rate:代表两个全连接层后面的dropout层
# --------------------------------------------- #
class attention(layers.Layer):

    def __init__(self,
                 num_heads,
                 projection_dim=64,
                 qkv_bias=False,
                 atten_drop_rate=0.,
                 proj_drop_rate=0.):
        super().__init__()
        self.num_heads = num_heads
        self.projection_dim = projection_dim
        self.qkv_bias = qkv_bias
        # 计算每个head的通道数
        self.head_channel = self.projection_dim // self.num_heads
        # 公式的分母,根号d
        self.scale = self.head_channel**0.5
        self.drop1 = layers.Dropout(rate=atten_drop_rate)
        self.dense1 = layers.Dense(self.projection_dim)
        self.drop2 = layers.Dropout(rate=proj_drop_rate)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        # 获取输入特征图的shape=[b,197,768]
        # 调整维度 [b,197,768*3]==>[b,197,3,num_heads,c//num_heads]
        inputs = tf.reshape(
            inputs, [batch_size, -1, 3, self.num_heads, self.head_channel])
        # 公式的分母,根号d
        scale = self.head_channel**0.5

        # 维度重排 [b,197,3,num_heads,c//num_heads]==>[3,b,num_heads,197,c//num_heads]
        inputs = tf.transpose(inputs, perm=[2, 0, 3, 1, 4])
        # 获取q、k、v的值==>[b,num_heads,197,c//num_heads]
        q, k, v = inputs[0], inputs[1], inputs[2]

        # 矩阵乘法, q 乘 k 的转置,除以缩放因子。矩阵相乘计算最后两个维度
        # [b,num_heads,197,c//num_heads] * [b,num_heads,c//num_heads,197] ==> [b,num_heads,197,197]
        atten = tf.matmul(a=q, b=k, transpose_b=True) / scale
        # 对每张特征图进行softmax函数
        atten = tf.nn.softmax(atten, axis=-1)
        # 经过dropout层
        atten = self.drop1(atten)
        # 再进行矩阵相乘==>[b,num_heads,197,c//num_heads]
        atten = tf.matmul(a=atten, b=v)

        # 维度重排==>[b,197,num_heads,c//num_heads]
        x = tf.transpose(atten, perm=[0, 2, 1, 3])
        # 维度调整==>[b,197,c]==[b,197,768]
        x = tf.reshape(x, [batch_size, -1, self.projection_dim])

        # 调整之后再经过一个全连接层提取特征==>[b,197,768]
        x = self.dense1(x)
        # 经过dropout
        x = self.drop2(x)
        return x


# 8 Build the ViT model
# 建立 ViT 模型
# ViT模型由多个Transformer块组成,每个块使用layers.MultiHeadAttention层作为自注意机制
# Transformer块生成一个[batch_size,num_patches,projection_dim]张量
# 通过softmax分类器头处理以生成最终的类别概率输出。
def create_vit_classifier():
    # 输入数据形状。
    inputs = layers.Input(shape=input_shape)

    # 数据增强。
    augmented = data_augmentation(inputs)

    # 将图片切分成patch并embedding,
    # 必须使用超参数patch_size,有两种方式:
    # 1. 直接卷积,这种方便,但灵活性差点
    # projection_1= layers.Conv2D(projection_dim,
    #                        patch_size,
    #                        strides=patch_size,
    #                        padding="valid",
    #                        name="patch_embed.proj")(augmented)
    # projection = layers.Reshape(
    #     ((image_size // patch_size) * (image_size // patch_size),
    #      projection_dim))(projection_1)
    # 2. 实现Patches类,并全连接映射,使用超参数projection_dim
    patches = Patches(patch_size)(augmented)
    projection = layers.Dense(units=projection_dim)(patches)  # [batch,196,768]

    # 编码图像拼接块。
    encoded_patches = PatchEncoder()(projection)

    # 创建多个Transformer块。
    for _ in range(transformer_layers):
        # 第一层归一化。
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)

        #创建多头注意力层。
        qkv = layers.Dense(int(projection_dim * 3))(x1)
        attention_output = attention(
            num_heads=num_heads,
            projection_dim=projection_dim,
            # 注意:head_channel=projection_dim//num_heads
            qkv_bias=False,
            atten_drop_rate=0.,
            proj_drop_rate=0.)(qkv)

        # attention_output = layers.MultiHeadAttention(num_heads=num_heads,
        #                                              key_dim=projection_dim,
        #                                              dropout=0.1)(x1, x1)

        # 跳跃连接1
        x2 = layers.Add()([attention_output, encoded_patches])
        # 第二层归一化
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # 跳跃连接2
        encoded_patches = layers.Add()([x3, x2])

    # 创建一个 [batch_size, projection_dim] 张量。
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)

    # 取出特征图的类别标签,在第(2)步中我们把类别标签放在了最前面
    cls_ticks = layers.Lambda(lambda v: v[:, 0],
                              name="ExtractToken")(representation)
    # 全连接层分类
    logits = layers.Dense(num_classes)(cls_ticks)
    soft = layers.Softmax()(logits)
    # 创建Keras模型。
    model = keras.Model(inputs=inputs, outputs=soft)
    return model


# 9 Compile, train, and evaluate the mode
# 编译、训练和评估模型
def run_experiment(model):
    # 定义优化器
    optimizer = tf.optimizers.AdamW(learning_rate=learning_rate,
                                    weight_decay=weight_decay)

    # 编译模型,指定优化器和损失函数,同时定义评价指标
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5,
                                                        name="top-5-accuracy"),
        ],
    )

    # 设定模型训练过程中的回调函数,用于保存模型参数
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",  # 监控的评价指标
        save_best_only=True,  # 仅保存最好的模型
        save_weights_only=True,  # 仅保存模型参数
    )

    # 训练模型
    history = model.fit(
        x=x_train,  # 输入特征
        y=y_train,  # 输入标签
        batch_size=batch_size,  # 批次大小
        epochs=num_epochs,  # 训练轮数
        validation_split=0.1,  # 用于验证的数据比例
        callbacks=[checkpoint_callback],  # 回调函数列表
    )

    # 加载保存的最优模型参数,并在测试集上进行评估
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test,
                                                 y_test)  # 返回损失函数值、评价指标值
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # 返回训练历史记录
    return history


# 创建一个 VIT 分类模型
vit_classifier = create_vit_classifier()
# 运行训练实验
history = run_experiment(vit_classifier)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss")
plot_history("top-5-accuracy")

原文链接:https://blog.csdn.net/dgvv4/article/details/124792386
https://zhuanlan.zhihu.com/p/626375905
https://keras.io/examples/vision/image_classification_with_vision_transformer/

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,362评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,330评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,247评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,560评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,580评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,569评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,929评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,587评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,840评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,596评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,678评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,366评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,945评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,929评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,165评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,271评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,403评论 2 342

推荐阅读更多精彩内容