介绍
ESRGAN是一个较新的的低分辨率转高分辨率的GAN模型,在SRGAN的基础上做了增强。
其论文在ESRGAN论文
其代码在ESRGAN仓库,该仓库只提供了简单的demo测试代码。完整的训练和测试代码在BasicSR仓库中。
如果要进一步学习,给出2篇论文综述作为参考:
综述1
综述2
初次运行ESRGAN
- 安装环境
conda install numpy pip install opencv-python==3.4.5.20 conda install python-lmdb pip install tensorboardX # 进入 https://pytorch.org/get-started/locally/ 找到安装pytorch合适的指令。我这里原来是Linux conda python3.6 CUDA10 conda install pytorch torchvision cudatoolkit=10.0 -c pytorch # 然而由于conda镜像没了,需要用pip了 pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl pip3 install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp37-cp37m-linux_x86_64.whl
- 拉代码
git clone https://github.com/xinntao/ESRGAN.git
下载模型到models中
https://pan.baidu.com/s/1-Lh6ma-wXzfH8NqeBtPaFQ
运行下面的代码,结果在result中。
python test.py models/RRDB_ESRGAN_x4.pth
python test.py models/RRDB_PSNR_x4.pth
初次使用BasicSR测试ESRGAN(SRGAN)模型
- 拉代码
git clone https://github.com/xinntao/BasicSR.git
- 进入codes文件夹
cd codes
- 修改 options/test/test_ESRGAN.json
- datasets dataroot_HR 将后面路径改为自己的训练数据文件夹,文件夹内存放的是png文件;或者改为lmdb文件。
- path root 改为自己的BasicSR项目路径
- 将刚刚在ESRGAN中用到的model放到pretrain_model_G的目录下面。
- 其他暂时不用动,我本机配置如下所示。
{ "name": "RRDB_ESRGAN_x4" , "suffix": "_ESRGAN" , "model": "srragan" , "scale": 4 , "gpu_ids": [0] , "datasets": { "test_1": { // the 1st test dataset "name": "set5" , "mode": "LRHR" , "dataroot_HR": "/root/addition_store/DIV2K_train_HR" } } , "path": { "root": "/home/student_docker/zlh/BasicSR" , "pretrain_model_G": "../experiments/pretrained_models/RRDB_ESRGAN_x4.pth" } , "network_G": { "which_model_G": "RRDB_net" // RRDB_net | sr_resnet , "norm_type": null , "mode": "CNA" , "nf": 64 , "nb": 23 , "in_nc": 3 , "out_nc": 3 , "gc": 32 , "group": 1 } }
- 运行测试代码
python test.py -opt options/test/test_ESRGAN.json
- 如果需要跑其他的测试代码,见其他测试
训练ESRGAN(SRGAN)模型
准备数据(DIV2K)
- 从DIV2K official page或者百度云下载
- 有几个方法可以让IO速度变快
- 将HDD改成SSD
- 将图片数据集改成更小的子图切片(sub-images)。见3和4
- 将原始数据改成lmdb格式。见5和6
- 修改codes/scripts/extract_subimgs_single.py文件的路径
input_folder = '/root/addition_store/DIV2K_train_HR' # 输入图片路径 save_folder = '/root/addition_store/DIV2K_train_HR_sub' # 输出图片路径
- 运行
python scripts/extract_subimgs_single.py
执行切片操作 - 修改codes/scripts/create_lmdb.py
img_folder = '/root/addition_store/DIV2K_train_HR_sub/*' # glob matching pattern lmdb_save_path = '/root/addition_store/DIV2K_train_HR_sub.lmdb' # must end with .lmdb mode = 2
- 运行
python scripts/create_lmdb.py
将数据改成lmdb格式
训练
-
修改options/train/train_ESRGAN.json
"name": "002_RRDB_ESRGAN_x4_DIV2K" "train" "dataroot_HR": "/root/addition_store/DIV2K_train_HR_sub.lmdb" "val" "dataroot_HR": "/root/addition_store/DIV2K_valid_HR" "path" "root": "/home/student_docker/zlh/BasicSR"
-
运行
python train.py -opt options/train/train_ESRGAN.json
-
tensorboard可视化
tensorboard --logdir=../tb_logger
进入http://localhost.localdomain:6006可看到训练过程