1. 关于mxnet的一些资料:
- dmlc/mxnet-notebook: github上搜吧, mxnet作者们出的一个很好的可以用来当做tutorial的东西,用ipython notebook过一遍。
- online document: http://mxnet.io/index.html 直接看API下的各种接口就可以了,但是严重不全...
- github上各种issue,事实证明大家的问题各不一样,总的数量又很少,反正感觉自己遇到的都是新问题。。
2. Module API
http://mxnet.io/api/python/module.html
The most widely used module class is Module, which wraps a Symbol
and one or more Executors. We construct a module by specify:
- symbol : the network Symbol
- context : the device (or a list of devices) for execution
- data_names : the list of data variable names
- label_names : the list of label variable names
The module API provides an intermediate- and high-level interface for performing computation with neural networks in MXNet. A module is an instance of subclasses of BaseModule. The most widely used module class is simply called Module, which wraps a Symbol and one or more Executors. For a full list of functions, see BaseModule. Each subclass of modules might have some extra interface functions. In this topic, we provide some examples of common use cases. All of the module APIs are in the mxnet.module namespace, simply called mxnet.mod.
- Module模块可以用来构建你的整个网络,比如
mxnet.mod.Module(...)
。除此之外,我们还要用fit()来指定输入(train, val)然后进行训练,可以执行predict(),score()等来得到输出,用set_params()等来将训练好的参数加载进网络, init_params()来初始化网络参数,可以选择不同的初始化方法,比如可以指定Xavier。总之对整个网络的各种设置、训练、初始化等都是由Module来wrap起来的。 - fit()这个函数想特别说一下,它可以指定train_data和val_data来进行训练和交叉验证,两个变量都是迭代器类型,也就是 io API下定义的。不过val是可选项,还可以在这里指定optimizer和它的params,进行初始化等工作。
-
另外想特别说的是bind(),因为在这里栽的挺惨的...bind的作用是得到一个真正可以跑的Executor,如果我们只是搭建了一个module,它是不可以用的,因为我们连内存都没有去申请。也就是说,bind之后才能fit。
要注意的是,这个是module下的bind,我们只需要给出要输入网络的data和label就可以了,data_shapes和label_shapes分别是namedtuple类型,在这里我们需要知道他们的name和shape,并按照网络的组织一步步得到每一层的参数大小信息,利用这个去申请cpu或者gpu内存。注意module中的每一层都会有一个自己的名字。用module去bind之后,可以得到整个网络的大小。感兴趣的可以看一下mxnet/python/mxnet/module/module.py下的bind()函数,它进入到executor_group.py下的DataParallelExecutorGroup()中,可以得到网络的一组executor,注意每一块symbol (后面会说到symbol,这里先理解为组成网络的每一层或者每一块)是对应一个executor的,这也就是为什么每一个module网络其实有一组executor。bind_exec()中会把每一个executor都存在一个list里面,贴一下代码比较直观
def bind_exec(self, data_shapes, label_shapes, shared_group):
"""
Bind executors on their respective devices.
Parameters
----------
data_shapes : list
label_shapes : list
shared_group : DataParallelExecutorGroup
"""
self.execs = []
for i in range(len(self.contexts)):
self.execs.append(self._bind_ith_exec(i, data_shapes, label_shapes, shared_group))
............(以下省略一坨)
self._bind_ith_exec()最后一句是:
executor = self.symbol.bind(ctx=context, args=arg_arrays, args_grad=grad_arrays, aux_states=aux_arrays, grad_req=self.grad_req, shared_exec=shared_exec)
这里的symbol.bind和module.bind是不一样的,一开始我给混了,其实这个bind是针对每一个symbol的,想看懂symbol和executor的可以移步这里。一定要明白:只有bind之后才真正有了executor,看到网上有人这样说:
Executor申请所有需要的内存和临时变量来进行计算。一旦绑定好,Executor会一直同一段内存空间做计算输入和输出。在执行一个符号表达式前,需要对所有的自由变量进行赋值,执行完forward后softmax的输出。
3. Model API
使用方法很简单,具体参考http://mxnet.io/api/python/model.html#model-api-reference
4. Symbol API
MXNet使用多值输出的符号表达式来声明计算图。符号是由操作子构建而来。一个操作子可以是一个简单的矩阵运算“+”,也可以是一个复杂的神经网络里面的层,例如卷积层。
# Output may vary
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)
net = mx.sym.SoftmaxOutput(data=net, name='out')
mx.viz.plot_network(net, shape={'data':(100,200)})
上述代码摘自mxnet notebook,它构建了一个非常简单的网络。可以看出,我们在搭建网络层次的时候需要指明输入、symbol的名字和一些其他的因结构而异的参数。
我们可以使用mxnet提供的各种神经网络层的符号接口,比如dropout, convolution, deconvolution, softmax等来搭建自己的网络。通常来讲,搭建复杂网络的时候我们可以在mxnet的symbol上再自己进行一次封装,比如把conv, batchnorm, relu层一起构建为一个网络模块方便调用。每一个层的命名也决定了这一层参数的命名,比如fc2层的参数名称默认为fc2_weight和fc2_bias。当我们fine tune一个网络,拿到的.params参数文件就是依靠这些名字把参数fit进网络的,通常来讲比如我们用ImageNet上的pre-trained网络来做我们自己的任务,如果最后层是fully-connected层那么很可能维度不一致,比如我们只分10类,而pre-trained网络分1000类。这个时候为了让set_params()的时候最后一层的参数不fit进去,我们可以把已有的net的最后层的名字改掉,这样参数就不会fit进去了,只需要重新初始化再去fine tune就可以了。
5. IO API
mxnet提供了几种不同的数据输入方式,最基本的都是从DataIter继承而来。比较常用的有mxnet.io.NDArrayIter, mxnet.io.ImageRecordIter。数据迭代器需要有reset()以进行下一轮迭代, 并需要提供读入下一组batch的next()接口,provide_data()和provide_label()提供输入输出数据的维度。
一种用于图片分类的常用数据迭代器就是ImageRecordIter,使用这个迭代器需要给出图片list文件,进而用im2rec生成可以读入的文件。它的缺点是对单张图片只能用单个label或者几个label,对于parsing任务来说就不能用了,只能用NADarrayIter,但是这个迭代器缺少一些必要的对图片的预处理,比如对图片进行在线crop。如果需要实现这些,就要我们自己去实现一个迭代器,实现的时候可以参考mxnet下example/fcn-xs这个例子。