- With the 'poly' learning rate policy, the current learning rate is the initial learning rate multiplied by a decay factor: lr = initial_lr * (1 - cur_iter / max_iter) ** lr_pow (a quick numeric check follows this list).
- max_iter = epoch_iter * num_epoch
- epoch_iter: the number of batches the training loader yields per epoch
- num_epoch: the number of epochs to train for
- lr_pow: 0.9
- lr_encoder: 2e-2
- beta1: 0.9
- weight_decay: 1e-4
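A minimal numeric sketch of the decay curve, assuming illustrative values epoch_iter = 5000 and num_epoch = 20 (these two numbers are assumptions for the demo, not from the original settings):

# Illustrative values only; the real epoch_iter/num_epoch come from the dataset.
epoch_iter = 5000                 # assumed batches per epoch
num_epoch = 20                    # assumed number of epochs
max_iter = epoch_iter * num_epoch
lr_pow = 0.9
lr_encoder = 2e-2

for cur_iter in (0, max_iter // 2, max_iter - 1):
    scale = (1.0 - cur_iter / max_iter) ** lr_pow
    print(f"iter {cur_iter:6d}: lr = {lr_encoder * scale:.6f}")
# iter      0: lr = 0.020000
# iter  50000: lr = 0.010718
# iter  99999: lr = 0.000001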
def adjust_learning_rate(optimizers, cur_iter, args):
    # Poly decay: scale the base LRs by (1 - cur_iter / max_iters) ** lr_pow.
    scale_running_lr = ((1. - float(cur_iter) / args.max_iters) ** args.lr_pow)
    args.running_lr_encoder = args.lr_encoder * scale_running_lr
    args.running_lr_decoder = args.lr_decoder * scale_running_lr

    # Write the rescaled LRs into every parameter group of both optimizers.
    (optimizer_encoder, optimizer_decoder) = optimizers
    for param_group in optimizer_encoder.param_groups:
        param_group['lr'] = args.running_lr_encoder
    for param_group in optimizer_decoder.param_groups:
        param_group['lr'] = args.running_lr_decoder
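A hedged usage sketch of the function above. The args fields mirror the ones the function reads; lr_decoder = 2e-2 and the stand-in networks are assumptions for the demo (only lr_encoder is listed above):

import argparse
import torch
import torch.nn as nn

args = argparse.Namespace(
    max_iters=100000, lr_pow=0.9,
    lr_encoder=2e-2, lr_decoder=2e-2)          # lr_decoder value assumed
enc, dec = nn.Linear(4, 4), nn.Linear(4, 4)    # stand-in encoder/decoder
optimizers = (torch.optim.SGD(enc.parameters(), lr=args.lr_encoder),
              torch.optim.SGD(dec.parameters(), lr=args.lr_decoder))

adjust_learning_rate(optimizers, cur_iter=50000, args=args)
print(args.running_lr_encoder)   # halfway: 2e-2 * 0.5 ** 0.9 ≈ 0.0107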
##########################################################
optimizer_encoder = torch.optim.SGD(
    group_weight(net_encoder),
    lr=args.lr_encoder,
    momentum=args.beta1,             # beta1 is reused as SGD momentum
    weight_decay=args.weight_decay)
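group_weight is not defined in this excerpt. Below is a minimal sketch of the usual pattern behind such a helper, applying weight_decay only to conv/linear weights while exempting biases and BatchNorm parameters; the exact grouping in the original codebase may differ:

import torch.nn as nn

def group_weight(module):
    # Split parameters so the optimizer's weight_decay hits only the
    # conv/linear weights, not biases or BatchNorm affine parameters.
    decay, no_decay = [], []
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.modules.conv._ConvNd)):
            decay.append(m.weight)
            if m.bias is not None:
                no_decay.append(m.bias)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if m.weight is not None:
                no_decay.append(m.weight)
            if m.bias is not None:
                no_decay.append(m.bias)
    return [dict(params=decay), dict(params=no_decay, weight_decay=0.)]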
# adjust the learning rate once per training iteration
# alternative: for i, data in enumerate(train_loader):
iterator_train = iter(loader_train)
for i in range(args.epoch_iters):    # one pass over the epoch, not max_iter
    batch_data = next(iterator_train)
    # Global iteration index across epochs (epoch counts from 1), so the
    # poly decay keeps progressing instead of restarting every epoch.
    cur_iter = i + (epoch - 1) * args.epoch_iters
    adjust_learning_rate(optimizers, cur_iter, args)
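The snippet above references epoch without defining it. A hedged sketch of the enclosing epoch loop, assuming args.num_epoch drives it (the actual forward/backward training step is elided):

for epoch in range(1, args.num_epoch + 1):
    iterator_train = iter(loader_train)    # rewind the loader each epoch
    for i in range(args.epoch_iters):
        batch_data = next(iterator_train)
        cur_iter = i + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)
        # ... forward pass, loss, backward, optimizer steps ...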