To use deep learning models effectively, a structured, object-oriented project layout goes a long way: it improves code reusability and lets us get into the main project faster and focus on its core parts (the model, the training loop, and so on).
This article introduces a TensorFlow project template that combines simplicity, folder-structure best practices, and good OOP design. The main idea is that a lot of boilerplate is shared every time you start a TensorFlow project, so wrapping up all of this shared code means that for each new TensorFlow project you only have to change the core idea.
GitHub: https://github.com/MrGemy95/Tensorflow-Project-Template
Environment
- Python 3.6
- Tensorflow-gpu 1.8.0
Directory structure
- base: holds the abstract base classes for the model and the trainer, which standardize the class structure.
- model: holds the neural network model classes.
- trainer: holds the model trainer classes.
- mains: holds the entry point(s) of the project.
- data_loader: holds the data-handling code.
- utils: holds the utility classes.
├── base
│   ├── base_model.py - this file contains the abstract class of the model.
│   └── base_train.py - this file contains the abstract class of the trainer.
│
├── model - this folder contains any model of your project.
│   └── example_model.py
│
├── trainer - this folder contains trainers of your project.
│   └── example_trainer.py
│
├── mains - here's the main(s) of your project (you may need more than one main).
│   └── example_main.py - here's an example of main that is responsible for the whole pipeline.
│
├── data_loader
│   └── data_generator.py - here's the data_generator that is responsible for all data handling.
│
└── utils
    ├── logger.py
    └── any_other_utils_you_need
Main file structure
Base
- base_model
BaseModel is an abstract class that any model you define must inherit from. The idea behind it is that all models share many common methods.
These shared methods include:
- save: saves a checkpoint file.
- load: loads the latest checkpoint file.
- cur_epoch and global_step counters: two variables used to track the current epoch and the global step.
- init_saver: an abstract method that initializes the saver used to save and load checkpoints; override it in the model you implement.
- build_model: an abstract method that defines the model; override it in the model you implement.
import tensorflow as tf


class BaseModel:
    def __init__(self, config):
        self.config = config
        # init the global step
        self.init_global_step()
        # init the epoch counter
        self.init_cur_epoch()

    # save function that saves the checkpoint in the path defined in the config file
    def save(self, sess):
        print("Saving model...")
        self.saver.save(sess, self.config.checkpoint_dir, self.global_step_tensor)
        print("Model saved")

    # load latest checkpoint from the experiment path defined in the config file
    def load(self, sess):
        latest_checkpoint = tf.train.latest_checkpoint(self.config.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {} ...\n".format(latest_checkpoint))
            self.saver.restore(sess, latest_checkpoint)
            print("Model loaded")

    # just initialize a tensorflow variable to use it as epoch counter
    def init_cur_epoch(self):
        with tf.variable_scope('cur_epoch'):
            self.cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
            self.increment_cur_epoch_tensor = tf.assign(self.cur_epoch_tensor, self.cur_epoch_tensor + 1)

    # just initialize a tensorflow variable to use it as global step counter
    def init_global_step(self):
        # DON'T forget to add the global step tensor to the tensorflow trainer
        with tf.variable_scope('global_step'):
            self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

    def init_saver(self):
        # just copy the following line in your child class
        # self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
        raise NotImplementedError

    def build_model(self):
        raise NotImplementedError
- base_train
This is an abstract trainer class; the individual training methods must be overridden in the trainer you implement.
import tensorflow as tf


class BaseTrain:
    def __init__(self, sess, model, data, config, logger):
        self.model = model
        self.logger = logger
        self.config = config
        self.sess = sess
        self.data = data
        self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess.run(self.init)

    def train(self):
        for cur_epoch in range(self.model.cur_epoch_tensor.eval(self.sess), self.config.num_epochs + 1, 1):
            self.train_epoch()
            self.sess.run(self.model.increment_cur_epoch_tensor)

    def train_epoch(self):
        """
        implement the logic of epoch:
        - loop over the number of iterations in the config and call the train step
        - add any summaries you want using the summary
        """
        raise NotImplementedError

    def train_step(self):
        """
        implement the logic of the train step
        - run the tensorflow session
        - return any metrics you need to summarize
        """
        raise NotImplementedError
Model
Your own model class is a subclass of BaseModel. Implementing it involves the following steps:
- Inherit from the base class.
- Override the build_model and init_saver methods.
- Call both methods in the constructor.
from base.base_model import BaseModel
import tensorflow as tf


class ExampleModel(BaseModel):
    def __init__(self, config):
        super(ExampleModel, self).__init__(config)
        self.build_model()
        self.init_saver()

    def build_model(self):
        self.is_training = tf.placeholder(tf.bool)
        self.x = tf.placeholder(tf.float32, shape=[None] + self.config.state_size)
        self.y = tf.placeholder(tf.float32, shape=[None, 10])

        # network architecture
        d1 = tf.layers.dense(self.x, 512, activation=tf.nn.relu, name="dense1")
        d2 = tf.layers.dense(d1, 10, name="dense2")

        with tf.name_scope("loss"):
            self.cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=d2))
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_step = tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.cross_entropy,
                                                                                             global_step=self.global_step_tensor)
            correct_prediction = tf.equal(tf.argmax(d2, 1), tf.argmax(self.y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    def init_saver(self):
        # here you initialize the tensorflow saver that will be used in saving the checkpoints.
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
Trainer
Your own trainer class is a subclass of BaseTrain. Implementing it involves the following steps:
- Inherit from the base class.
- Override the train_epoch and train_step methods.
from base.base_train import BaseTrain
from tqdm import tqdm
import numpy as np


class ExampleTrainer(BaseTrain):
    def __init__(self, sess, model, data, config, logger):
        super(ExampleTrainer, self).__init__(sess, model, data, config, logger)

    def train_epoch(self):
        loop = tqdm(range(self.config.num_iter_per_epoch))
        losses = []
        accs = []
        for _ in loop:
            loss, acc = self.train_step()
            losses.append(loss)
            accs.append(acc)
        loss = np.mean(losses)
        acc = np.mean(accs)

        cur_it = self.model.global_step_tensor.eval(self.sess)
        summaries_dict = {
            'loss': loss,
            'acc': acc,
        }
        self.logger.summarize(cur_it, summaries_dict=summaries_dict)
        self.model.save(self.sess)

    def train_step(self):
        batch_x, batch_y = next(self.data.next_batch(self.config.batch_size))
        feed_dict = {self.model.x: batch_x, self.model.y: batch_y, self.model.is_training: True}
        _, loss, acc = self.sess.run([self.model.train_step, self.model.cross_entropy, self.model.accuracy],
                                     feed_dict=feed_dict)
        return loss, acc
Data Loader
This file defines a data-handling class responsible for loading the data, preprocessing it, and generating batches.
import numpy as np


class DataGenerator:
    def __init__(self, config):
        self.config = config
        # load data here
        self.input = np.ones((500, 784))
        self.y = np.ones((500, 10))

    def next_batch(self, batch_size):
        idx = np.random.choice(500, batch_size)
        yield self.input[idx], self.y[idx]
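Note that next_batch is a generator function: each call returns a fresh generator, and next(...) pulls a single batch from it, which is exactly how train_step above consumes it. A quick standalone sanity check (a usage sketch; the batch size 16 is arbitrary):

data = DataGenerator(config)  # config parsed from the JSON file shown below
batch_x, batch_y = next(data.next_batch(16))
print(batch_x.shape, batch_y.shape)  # (16, 784) (16, 10)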
Utils
- logger
Here you can define a Logger class that wraps the TensorFlow summary operations.
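A minimal sketch of such a Logger (assuming only the summarize(step, summaries_dict) interface used by ExampleTrainer above; the actual repository version is more elaborate):

import tensorflow as tf


class Logger:
    def __init__(self, sess, config):
        self.sess = sess
        self.config = config
        # a single FileWriter pointed at the summary dir defined in the config
        self.summary_writer = tf.summary.FileWriter(self.config.summary_dir, self.sess.graph)

    def summarize(self, step, summaries_dict=None):
        # write every scalar in the dict as a simple_value summary at the given step
        if summaries_dict is not None:
            for tag, value in summaries_dict.items():
                summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
                self.summary_writer.add_summary(summary, step)
            self.summary_writer.flush()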
- config
Here you can define a config module that parses the model configuration stored as JSON:
{
    "exp_name": "example",
    "num_epochs": 10,
    "num_iter_per_epoch": 10,
    "learning_rate": 0.001,
    "batch_size": 16,
    "state_size": [784],
    "max_to_keep": 5
}
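A minimal sketch of process_config (assuming attribute-style access to the parsed values and an experiments/<exp_name>/ directory layout, both our assumptions; the actual repository uses the bunch package for the same purpose):

import json
import os
from types import SimpleNamespace


def process_config(json_file):
    # parse the JSON file into an object with attribute access (config.batch_size, ...)
    with open(json_file, 'r') as f:
        config = SimpleNamespace(**json.load(f))
    # derive the experiment directories used by the logger and the saver (assumed layout)
    config.summary_dir = os.path.join("experiments", config.exp_name, "summary/")
    config.checkpoint_dir = os.path.join("experiments", config.exp_name, "checkpoint/")
    return config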
Main
This is the main entry point of the project, where we define the training process (and, if needed, the forward/inference pass).
The training process is defined as follows:
- Create the experiment directories.
- Create a Session object.
- Create a Data object.
- Create a Model object.
- Create a Logger object.
- Create a Trainer object.
- Call the trainer's train() method to start training.
import tensorflow as tf

from data_loader.data_generator import DataGenerator
from models.example_model import ExampleModel
from trainers.example_trainer import ExampleTrainer
from utils.config import process_config
from utils.dirs import create_dirs
from utils.logger import Logger
from utils.utils import get_args


def main():
    # capture the config path from the run arguments
    # then process the json configuration file
    try:
        args = get_args()
        config = process_config(args.config)
    except:
        print("missing or invalid arguments")
        exit(0)

    # create the experiments dirs
    create_dirs([config.summary_dir, config.checkpoint_dir])
    # create tensorflow session
    sess = tf.Session()
    # create your data generator
    data = DataGenerator(config)
    # create an instance of the model you want
    model = ExampleModel(config)
    # create tensorboard logger
    logger = Logger(sess, config)
    # create trainer and pass all the previous components to it
    trainer = ExampleTrainer(sess, model, data, config, logger)
    # load model if exists
    model.load(sess)
    # here you train your model
    trainer.train()


if __name__ == '__main__':
    main()
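For completeness, here is a minimal sketch of the two remaining helpers imported above, utils.utils.get_args and utils.dirs.create_dirs (assuming argparse-based argument parsing; the -c/--config flag is our assumption, chosen to match the args.config usage in main):

import argparse
import os


def get_args():
    # parse the --config argument that points at the JSON configuration file
    parser = argparse.ArgumentParser(description="TensorFlow project template")
    parser.add_argument('-c', '--config', default='None',
                        help='path to the JSON configuration file')
    return parser.parse_args()


def create_dirs(dirs):
    # create each experiment directory if it does not already exist
    for d in dirs:
        if not os.path.exists(d):
            os.makedirs(d)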