MXNet 特征点提取基本流程
以 Android 调用 MXNet 为例:
开源 MXNet 代码:incubator-mxnet
其中 Android 部分代码路径:incubator-mxnet/amalagamation/jni
文件 | 说明 |
---|---|
org_dmlc_mxnet_Predictor.h |
MXNet JNI 接口声明文件 |
predictor.cc |
MXNet JNI 接口实现文件 |
org/dmlc/mxnet/MxnetException.java |
MXNet JNI 接口相关的 Java 端报错文件(示例) |
org/dmlc/mxnet/Predictor.java |
MXNet JNI 接口相关的 Java 端接口文件(示例) |
分析一下 predictor.cc
文件的每个接口功能:
MXNet 的 JNI 接口
查看 org/dmlc/mxnet/Predictor.java
文件可以知道 MXNet 的 Android 端基本接口只有 4 个。
private native static long createPredictor(byte[] symbol, byte[] params, int devType, int devId, String[] keys, int[][] shapes);
private native static void nativeFree(long handle);
private native static float[] nativeGetOutput(long handle, int index);
private native static void nativeForward(long handle, String key, float[] input);
主要功能:
接口 | 描述 |
---|---|
createPredictor | 初始化 MXNet predictor |
nativeFree | 释放 MXNet 资源(关闭 MXNet 功能) |
nativeGetOutput | 获取特征点信息 |
nativeForward | 输入需要提取特征的元素数据 |
createPredictor
/*
* Class: org_dmlc_mxnet_Predictor
* Method: createPredictor
* Signature: ([B[BII[Ljava/lang/String;[[I)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_mxnet_Predictor_createPredictor
(JNIEnv *, jclass, jbyteArray symbol, jbyteArray params, jint devType, jint devId, jobjectArray keys, jobjectArray shapes);
创建 MXNet predictor (预测器),用于对图片数据提取特征。
参数名 | JNI 类型 | Java 类型 | 说明 |
---|---|---|---|
symbol | jbyteArray | byte[] | 模型 symbol 数据(字节流) |
params | jbyteArray | byte[] | 模型 params 数据(字节流) |
devType | jint | int | 机器学习使用的硬件类型,支持 CPU(1), GPU(2), CPU Pinned(3) 等 |
devId | jint | int | predictor 的设备 id (用于区分其它 MXNet 成员) |
keys | jobjectArray | String[] | 输入参数的名称,对于 feedforward 是 {"data"} |
shapes | jobjectArray | int[][] | 多组输入节点的 shape 数据 |
返回值 | jlong | long | 返回创建的 predictor 的句柄(通过该句柄使用不同的 MXNet predictor) |
nativeFree
/*
* Class: org_dmlc_mxnet_Predictor
* Method: nativeFree
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeFree
(JNIEnv *, jclass, jlong handle);
用于释放对应的 MXNet predictor 数据,回收资源。
参数名 | JNI 类型 | Java 类型 | 说明 |
---|---|---|---|
handle | jlong | long | predictor 句柄,用于找到对应 MXNet 数据进行释放 |
nativeGetOutput
/*
* Class: org_dmlc_mxnet_Predictor
* Method: nativeGetOutput
* Signature: (JI)[F
*/
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_mxnet_Predictor_nativeGetOutput
(JNIEnv *, jclass, jlong handle, jint index);
获取特征点数据(已经经过机器学习根据模型提取特征点)。
参数名 | JNI 类型 | Java 类型 | 说明 |
---|---|---|---|
handle | jlong | long | predictor 句柄(同上) |
index | jint | int | shape 数据索引,获取第 index 组 shape 数据(MXNet 支持多种检测 shape) |
nativeForward
/*
* Class: org_dmlc_mxnet_Predictor
* Method: nativeForward
* Signature: (JLjava/lang/String;[F)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeForward
(JNIEnv *, jclass, jlong handle, jstring key, jfloatArray input);
输入图片数据,用于提取特征点。
参数名 | JNI 类型 | Java 类型 | 说明 |
---|---|---|---|
handle | jlong | long | predictor 句柄(同上) |
key | jstring | String | 设置输入数据的参数名称 |
input | jfloatArray | float[] | 图片数据,注意把 [ Y, X, RGB ] 的维度转为 [ RGB, Y, X ] 的维度 |
其中 input (float[]) 数据需要 RGB 数据(不需要 Alpha 透明度),而且还要进行维度转换:
[ 行,列,色深(RGB) ] 转为 [ 色深(RGB),行,列 ]。
另外不同的模型对 RGB 值会有一些偏移,也需要注意不同模型的参数。
参考代码如下:
public float[] inputFromImage(Bitmap[] bmps, float meanR, float meanG, float meanB) {
if (bmps.length == 0) return null;
int width = bmps[0].getWidth();
int height = bmps[0].getHeight();
float[] buf = new float[height * width * 3 * bmps.length];
for (int x=0; x<bmps.length; x++) {
Bitmap bmp = bmps[x];
if (bmp.getWidth() != width || bmp.getHeight() != height)
return null;
int[] pixels = new int[ height * width ];
bmp.getPixels(pixels, 0, width, 0, 0, height, width);
int start = width * height * 3 * x;
for (int i=0; i<height; i++) {
for (int j=0; j<width; j++) {
int pos = i * width + j;
int pixel = pixels[pos];
buf[start + pos] = Color.red(pixel) - meanR;
buf[start + width * height + pos] = Color.green(pixel) - meanG;
buf[start + width * height * 2 + pos] = Color.blue(pixel) - meanB;
}
}
}
return buf;
}