TORCH09-04:使用TorchC++实现Lenet-5


  实现Lenet-5模型,实现模型的训练与验证,并编写两种识别方式:
    1. 直接使用数据集中原始数据测试;
    2. 使用图片测试;


Lenet-5 模型

  • 把这个模型代码放入下面两个程序中编译。
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
// BatchNorm
// Dropout
class Lenet5 : public torch::nn::Module{
private:
    // 卷积特征运算
    torch::nn::Conv2d  conv1;
    torch::nn::Conv2d  conv2;
    torch::nn::Conv2d  conv3;
    torch::nn::Linear  fc1;
    torch::nn::Linear  fc2;

public:
    Lenet5():
    conv1(torch::nn::Conv2dOptions(1, 6, 5).stride(1).padding(2)),  // 1 * 28 * 28 -> 6 * 28 * 28 -> 6 * 14 * 14
    conv2(torch::nn::Conv2dOptions(6, 16, 5).stride(1).padding(0)),  // 6 * 14 * 14 -> 16 * 10 * 10 -> 16 * 5 * 5
    conv3(torch::nn::Conv2dOptions(16, 120, 5).stride(1).padding(0)), // 16 * 5 * 5 -> 120 * 1 * 1 (不需要池化)
    fc1(120, 84),  // 120 -> 84
    fc2(84, 10){  // 84 -> 10 (分量最大的小标就是识别的数字)
        // 注册需要学习的矩阵(Kernel Matrix)
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv3", conv3);
        register_module("fc1", fc1);
        register_module("fc2", fc2);
    }

    // override
    torch::Tensor forward(torch::Tensor x){  // {n * 1 * 28 * 28}
        // 1. conv
        x = conv1->forward(x);   // {n * 6 * 28 * 28}
        x = torch::max_pool2d(x, 2);   // {n * 6 * 14 * 14}
        x = torch::relu(x); // 激活函数 // {n * 6 * 14 * 14}
        // 2. conv
        x = conv2->forward(x);   // {n * 16 * 10 * 10}
        x = torch::max_pool2d(x, 2);   // {n * 16 * 5 * 5}
        x = torch::relu(x); // 激活函数 // {n * 16 * 5 * 5}
        // 3. conv
        x = conv3->forward(x);   // {n * 120 * 1 * 1}
        x = torch::relu(x); // 激活函数 // {n * 120 * 1 * 1}
        // 做数据格式转换
        x = x.view({-1, 120});   // {n * 120}
        // 4. fc
        x = fc1->forward(x);
        x = torch::relu(x);
        
        // 5. fc 
        x = fc2->forward(x);
        return  torch::log_softmax(x, 1);   // CrossEntryLoss = log_softmax + nll
    }

};

训练与验证main.cpp

template <typename  DataLoader> 
void train(std::shared_ptr<Lenet5> &model,  DataLoader &loader,  torch::optim::Adam &optimizer){
    model->train();
    // 迭代数据
    int n = 0;
    for(torch::data::Example<torch::Tensor, torch::Tensor> &batch: loader){
        torch::Tensor data   = batch.data;
        auto target          = batch.target;
        optimizer.zero_grad(); // 清空上一次的梯度
        // 计算预测值
        torch::Tensor y = model->forward(data);
        // 计算误差
        torch::Tensor loss = torch::nll_loss(y, target);
        // 计算梯度: 前馈求导
        loss.backward();
        // 根据梯度更新参数矩阵
        optimizer.step();
        // 为了观察效果,输出损失
        // std::cout << "\t|--批次:" << std::setw(2) << std::setfill(' ')<< ++n 
        //           << ",\t损失值:" << std::setw(8) << std::setprecision(4) << loss.item<float>() << std::endl;
    }

    // 输出误差
}
template <typename DataLoader>
void  valid(std::shared_ptr<Lenet5> &model, DataLoader &loader) {
    model->eval();
    // 禁止求导的图跟踪
    torch::NoGradGuard  no_grad;
    // 循环测试集
    double sum_loss = 0.0;
    int32_t num_correct = 0;
    int32_t num_samples = 0;
    for(const torch::data::Example<> &batch: loader){
        // 每个批次预测值
        auto data = batch.data;
        auto target = batch.target;
        num_samples += data.sizes()[0];
        auto y = model->forward(data);
        // 计算纯预测的结果
        auto pred = y.argmax(1);
        // 计算损失值
        sum_loss += torch::nll_loss(y, target, {}, at::Reduction::Sum).item<double>();
        // 比较预测结果与真实的标签值
        num_correct += pred.eq(target).sum().item<int32_t>();
    }
    // 输出正确值
    std::cout << std::setw(8) << std::setprecision(4) 
        << "平均损失值:" << sum_loss / num_samples 
        << ",\t准确率:" << 100.0 * num_correct / num_samples << " %" << std::endl;
}

int main(int argc, const char** argv){
    
    // 数据集
    auto  ds_train = torch::data::datasets::MNIST(".\\data", torch::data::datasets::MNIST::Mode::kTrain);
    auto  ds_valid = torch::data::datasets::MNIST(".\\data", torch::data::datasets::MNIST::Mode::kTest);

    // torch::data::transforms::Normalize<> norm(0.1307, 0.3081);
    torch::data::transforms::Stack<> stack;

    // 数据批次加载器
    // auto n_train = ds_train.map(norm);
    auto s_train = ds_train.map(stack);
    auto train_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(s_train), 1000); 

    // auto n_valid = ds_valid.map(norm);
    auto s_valid = ds_valid.map(stack);
    auto valid_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(s_valid), 1000); 

    // 1. 创建模型对象
    std::shared_ptr<Lenet5> model = std::make_shared<Lenet5>();
    // for(auto &batch: *train_loader){
    //     auto data = batch.data;
    //     auto target = batch.target;
    //     data = data.view({-1, 1, 28, 28});
    //     auto pred = model->forward(data);
    //     // pred  <-> target 存在误差,计算误差,计算调整5 * 5 核矩阵的依据,调整的方向是 loss(pred - target) -> 0 
    // }


    // 优化器(管理模型中可训练矩阵)
    torch::optim::Adam  optimizer = torch::optim::Adam(model->parameters(), torch::optim::AdamOptions(0.001)); // 根据经验一般设置为10e-4 
    
    std::cout<< "开始训练" << std::endl;
    int epoch = 20;
    int interval = 1;   // 从测试间隔
    for(int e = 0; e < epoch; e++){
        std::printf("第%02d论训练\n", e+1);
        train(model, *train_loader, optimizer);
        if (e  % interval == 0){
            valid(model, *valid_loader);
        }
    }
    std:: cout << "训练结束" << std::endl;
    torch::save(model, "lenet5.pt");
    return 0;
}

识别实现main.cpp

int main(){
    const char * data_filename = ".\\data";
    // 加载模型
    std::shared_ptr<Lenet5> model = std::make_shared<Lenet5>();
    torch::load(model, "lenet5.pt");

    // 一. 使用测试集中数据识别
    auto imgs = torch::data::datasets::MNIST(data_filename, torch::data::datasets::MNIST::Mode::kTest);
    // 取一张图像
    for(int i = 0; i < 10; i++){
        torch::data::Example<> example = imgs.get(i);
        // std::cout << "识别的数字是:" << example.target.item<int32_t>() << std::endl;  
        // 获取图像
        torch::Tensor  a_img = example.data;
        // 预测
        a_img = a_img.view({-1, 1, 28, 28});  // 我们的模型只接受4为的固定的数据格式(N * C * H * W)(NCHW格式)
        torch::Tensor  y = model->forward(a_img);
        int32_t result = y.argmax(1).item<int32_t>();
        std::cout << "识别的结果是:" << result << "->" << example.target.item<int32_t>() <<  std::endl;
    }
    
    std::cout << "----------------------------------" << std::endl;

    // 二. 使用图像文件来识别
    // 读取图像
    
    cv::Mat im = cv::imread("img_9_9.png");   // 换图像,测试是否准确
    cv::cvtColor(im, im, cv::COLOR_BGR2GRAY);    // 注意:png图是3-4通道,需要转换为1通道灰度图。
    // 转换为Tensor,处理成0-1之间的数字
    im.convertTo(im, CV_32FC1, 1.0f / 255.0f);
    torch::Tensor  t_img = torch::from_blob(im.data, {1, 28, 28});
    t_img = t_img.view({-1, 1, 28, 28});
    // 识别
    torch::Tensor  y_ = model->forward(t_img);
    int32_t pred = y_.argmax(1).item<int32_t>();
    std::cout << "识别的结果是:" << pred <<  std::endl;
    return 0;
}

编译脚本CMakeLists.txt

    cmake_minimum_required(VERSION 3.16)

    project(main)
    set(CMAKE_PREFIX_PATH  "C:/libtorch")
    set(CMAKE_CXX_FLAGS  "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
    find_package(Torch REQUIRED)

    # opencv的配置
    include_directories("C:/opencv_new/install/include")
    link_directories("C:/opencv_new/install/x64/vc16/lib")

    add_executable(main main.cpp)
    target_link_libraries(main "${TORCH_LIBRARIES}"  "opencv_core420d.lib" "opencv_imgcodecs420d.lib" "opencv_imgproc420d.lib" )
    set_property(TARGET main PROPERTY CXX_STANDARD 11)


最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 206,723评论 6 481
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 88,485评论 2 382
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 152,998评论 0 344
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 55,323评论 1 279
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 64,355评论 5 374
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,079评论 1 285
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,389评论 3 400
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,019评论 0 259
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 43,519评论 1 300
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,971评论 2 325
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,100评论 1 333
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,738评论 4 324
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,293评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,289评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,517评论 1 262
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,547评论 2 354
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,834评论 2 345