0x00 背景知识
先放上一篇综述文章,对于理解NAS(网络结构搜索)的问题有很大的帮助:https://blog.csdn.net/c9Yv2cf9I06K2A9E/article/details/82321884
另外,DARTS搜索,强烈建议先看下inception的网络结构和nasnet的论文,DARTS的论文基础是建立在之上的,某种程度上可以看做是对nasnet的优化。
0x01 搜索思路
基于前人的经验(inception/nasnet),DARTS使用cell作为模型结构搜索的基础单元,所学习的单元堆叠成卷积网络,也可以递归连接形成递归网络。
cell内节点间先默认所有可能的操作连接,每个连接初始化权重参数值,结构搜索也就是训练这些权重参数,最终两节点间选取权重最大的操作作为最终结构参数。
训练过程中,交替训练网络结构参数和网络参数。
0x02 代码定义
genotype结构定义
normal=[(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 1), (‘skip_connect’, 0), (‘skip_connect’, 0), (‘dil_conv_3x3’, 2)], normal_concat=[2, 3, 4, 5]
取了genotype里的一个normal cell的定义及其对应的cell结构图首先说明下,这个定义的解释。DARTS搜索的也就是这个定义。
normal定义里(‘sep_conv_3x3’, 1)的0,1,2,3,4,5对应到图中的红色字体标注的。
从normal文字定义两个元组一组,映射到图中一个蓝色方框的节点(这个是作者搜索出来的结构,结构不一样,对应关系不一定是这样的)
sep_conv_xxxx表示操作,0/1表示输入来源
(‘sep_conv_3x3’, 1), (‘sep_conv_3x3’, 0) —-> 节点0
(‘sep_conv_3x3’, 0), (‘sep_conv_3x3’, 1) —-> 节点1
(‘sep_conv_3x3’, 1), (‘skip_connect’, 0) —-> 节点2
(‘skip_connect’, 0), (‘dil_conv_3x3’, 2) —-> 节点3
normal_concat=[2, 3, 4, 5] —-> cell输出c_{k}
DARTS搜索NOTE
首先明确,DARTS搜索实际只搜cell内结构,整个模型的网络结构是预定好的,比如多少层,网络宽度,cell内几个节点等;
在构建搜索的网络结构时,有几个特别的地方:
1.预构建cell时,采用的一个MixedOp:包含了两个节点所有可能的连接(genotype中的PRIMITIVES);
2.初始化了一个alphas矩阵,网络做forward时,参数传入,在cell里使用,搜索过程中所有可能连接都在时,计算mixedOp的输出,采用加权的形式。
3.训练过程对train数据每个step又切成两份: train和validate, train用来训练网络参数,validate用来训练结构参数。
0x03 关键代码片段
以下把代码中一些关键的,影响到理解DARTS的地方说明一下:
- file: train_search.py 第149行
architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)
logits = model(input)
loss = criterion(logits, target)
loss.backward()
nn.utils.clip_grad_norm(model.parameters(), args.grad_clip)
optimizer.step()
这里就是论文里近似后的交叉梯度下降,其中architect.step()是结构参数weights的梯度下降,optimizer.step()是网络参数的梯度下降。
- file: model_search.py
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
for primitive in PRIMITIVES:
op = OPS[primitive](C, stride, False)
if 'pool' in primitive:
op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
self._ops.append(op)
def forward(self, x, weights):
return sum(w * op(x) for w, op in zip(weights, self._ops)) # weighted op
这个是MixedOp,两节点间操作把PRIMITIVES里定义的所有操作都连接上,计算输出时利用传入的weights进行加权。
- file: model_search.py第47行
def forward(self, s0, s1, weights):
s0 = self.preprocess0(s0)
s1 = self.preprocess1(s1)
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states)) # all nodes before can be input, mixop.
offset += len(states) #0, 2, 5, 9
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
self.ops[], 实际是14(2+3+4+5)个MixedOp,2+3+4+5的解释,对于第一个内部节点,有两个可能的输入(c{k-1}, c_{k-2}),对于第二个内部节点,有三个可能的输入(两个同节点1,另加上第一个节点)……
代码里,weights[],也是一个长度14的list,前2个对应到第一个节点的两个输入的权重,第3~5这3个元素对应到第二个节点的三个输入的权重……这就是上面代码里offset的作用
- file: architect.py 第11行
class Architect(object):
def __init__(self, model, args):
self.network_momentum = args.momentum
self.network_weight_decay = args.weight_decay
self.model = model
self.optimizer = torch.optim.Adam(self.model.arch_parameters(), #arch_parameters,
lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
需要注意的是Architect里optimizer优化器的参数是model.arch_parameters(), 这个对应到的是model_search.py里定义的._arch_parameters,及初始化的各节点连接的权重。
def _initialize_alphas(self):
k = sum(1 for i in range(self._steps) for n in range(2+i)) # 2+i, 2 for two inputs, i=0,1,2,3, nodes before this. 2+3+4+5
num_ops = len(PRIMITIVES)
self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
self._arch_parameters = [
self.alphas_normal,
self.alphas_reduce,
]
- file: model_search.py 第133行
def _parse(weights):
# weights: [2 + 3 + 4 + 5][len(PRIMITIVES)]
gene = []
n = 2
start = 0
for i in range(self._steps): #ch: steps = 4
end = start + n
print('start=', start, 'end=', end, 'n=', n)
W = weights[start:end].copy()
print(W) # ch: add
# chenhua: for x, -max(W[x][...]), W[][] is the parameters for architect. lambda elect out the OP weights most.
edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
print(edges)
for j in edges: #ch: j, edges mean op, all possible ops between two node
print(j)
k_best = None
for k in range(len(W[j])): #ch: k, the weights for possible connection?
if k != PRIMITIVES.index('none'):
if k_best is None or W[j][k] > W[j][k_best]:
print('W[j][k]=', W[j][k], 'W[j][k_best]=', W[j][k_best])
k_best = k
gene.append((PRIMITIVES[k_best], j)) #ch: find ????
start = end
n += 1
return gene
# ch: alphas_xxx, parameters for architect??
gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())
concat = range(2+self._steps-self._multiplier, self._steps+2) #ch: step=4, mltiplier=3
print('concat', concat)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
print('genotype=', genotype)
return genotype
搜索过程中搜索出的结果(节点间的op)的打印,就是靠这个函数。
核心是找出两个节点间不为none的所有ops中权重最大的,就是最终的结果。
注意:weights[][]的size是[2 + 3 + 4 + 5][len(PRIMITIVES)]