Reference:
Using spark-scala to call a model trained with TensorFlow 2.0
1. Train and save the model with TF2:
import tensorflow as tf
from tensorflow.keras import models,layers,optimizers
## Number of samples
n = 800
## Generate a synthetic dataset for testing
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)
Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0) # @ is matrix multiplication; add Gaussian noise
## Build the model
tf.keras.backend.clear_session()
inputs = layers.Input(shape = (2,),name ="inputs") # name the input layer "inputs"
outputs = layers.Dense(1, name = "outputs")(inputs) # name the output layer "outputs"
linear = models.Model(inputs = inputs,outputs = outputs)
linear.summary()
## Train with the fit method
linear.compile(optimizer="rmsprop",loss="mse",metrics=["mae"])
linear.fit(X,Y,batch_size = 8,epochs = 100)
tf.print("w = ",linear.layers[1].kernel)
tf.print("b = ",linear.layers[1].bias)
## Save the model in the SavedModel format (writes saved_model.pb plus variables)
export_path = "/your_path/tf2_linear"
linear.save(export_path, save_format="tf")
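Because the data are generated from w0 = [2, -1] and b0 = 3 with zero-mean noise, a well-trained model's predictions should land close to the noise-free line y = 2*x1 - x2 + 3. The Java sketch below computes that noise-free reference so the Java predictions can be sanity-checked; the class name LinearReference is our own, and w0/b0 are hard-coded from the training script above.

```java
// Hypothetical helper: noise-free reference for the synthetic training data,
// using w0 = [[2.0], [-1.0]] and b0 = 3.0 from the Python script.
public class LinearReference {
    static final float[] W0 = {2.0f, -1.0f};
    static final float B0 = 3.0f;

    // Computes X @ w0 + b0 for every row of x (without the N(0, 2) noise term)
    static float[] predict(float[][] x) {
        float[] y = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            float sum = B0;
            for (int j = 0; j < W0.length; j++) {
                sum += x[i][j] * W0[j];
            }
            y[i] = sum;
        }
        return y;
    }

    public static void main(String[] args) {
        // Same sample rows that the Java prediction example feeds to the model
        float[][] input = {
            {2.6327686f, -9.201903f},
            {-1.3209248f, 8.569574f},
            {-5.6642127f, 3.3681698f},
            {9.604832f, 5.9664965f},
            {-0.8812313f, -6.76733f}
        };
        for (float v : predict(input)) {
            System.out.println(v);
        }
    }
}
```

The model's outputs will not match these values exactly, since the learned w and b only approximate w0 and b0.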
Saved model directory:
~/demo/your_path tree
.
└── tf2_linear
    ├── assets
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index

3 directories, 3 files
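A cheap pre-check that the export path really has the layout shown above can make a mistyped path easier to diagnose before TensorFlow is involved. This is a minimal sketch using java.nio.file; the class and method names are hypothetical, and it only checks for saved_model.pb and the variables/ directory, not whether their contents are valid.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical pre-check for the directory written by linear.save(...)
public class SavedModelLayoutCheck {

    // True if dir holds saved_model.pb and a variables/ subdirectory,
    // matching the tree output above; file contents are not validated.
    static boolean looksLikeSavedModel(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(dir.resolve("variables"));
    }

    public static void main(String[] args) {
        // "/your_path/tf2_linear" is the placeholder export path from the article
        Path dir = Paths.get(args.length > 0 ? args[0] : "/your_path/tf2_linear");
        System.out.println(dir + " looks like a SavedModel: " + looksLikeSavedModel(dir));
    }
}
```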
2. Load the model in Java and run predictions
Inspect the model's signatures (the Java code needs the input and output tensor names shown here):
~/demo/your_path saved_model_cli show --dir ./tf2_linear --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
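The Java code refers to tensors by the `name` fields printed above (serving_default_inputs:0 and StatefulPartitionedCall:0), not by the signature keys inputs/outputs. Each name has the form operation_name:output_index. As far as I know, TF Java 1.x's Session.Runner.feed(String, Tensor) and fetch(String) split this form themselves, and feed(String operation, int index, Tensor) takes the two parts separately. The sketch below mirrors that split for illustration only; the TensorName class is hypothetical and not part of the TensorFlow API.

```java
// Hypothetical helper: decomposes "operation_name:output_index" tensor names
public class TensorName {
    final String operation;
    final int index;

    TensorName(String operation, int index) {
        this.operation = operation;
        this.index = index;
    }

    // A name without a numeric ":N" suffix refers to output 0 of the operation
    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        if (colon == -1 || colon == name.length() - 1) {
            return new TensorName(name, 0);
        }
        try {
            return new TensorName(name.substring(0, colon),
                    Integer.parseInt(name.substring(colon + 1)));
        } catch (NumberFormatException e) {
            // Non-numeric suffix: treat the whole string as the operation name
            return new TensorName(name, 0);
        }
    }

    public static void main(String[] args) {
        for (String n : new String[] {"serving_default_inputs:0", "StatefulPartitionedCall:0"}) {
            TensorName t = parse(n);
            System.out.println(t.operation + " -> output " + t.index);
        }
    }
}
```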
Maven dependencies:
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.15.0</version>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.73</version>
</dependency>
Java code:
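package com.ml.demo.tf;

import com.alibaba.fastjson.JSON;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class PredictNN {
    public static void main(String[] args) {
        // Load the SavedModel with the "serve" tag-set reported by saved_model_cli;
        // closing the bundle also releases its session and graph
        try (SavedModelBundle bundle = SavedModelBundle.load("/your_path/tf2_linear", "serve")) {
            Session session = bundle.session();
            // One row per sample, two features each: matches input shape (-1, 2)
            float[][] input = {
                {2.6327686f, -9.201903f},
                {-1.3209248f, 8.569574f},
                {-5.6642127f, 3.3681698f},
                {9.604832f, 5.9664965f},
                {-0.8812313f, -6.76733f}
            };
            System.out.println("input: \n" + JSON.toJSONString(input));
            // Feed and fetch by the tensor names from the serving_default signature;
            // both tensors hold native memory, so close them when done
            try (Tensor<?> inputTensor = Tensor.create(input);
                 Tensor<?> resultTensor = session.runner()
                         .feed("serving_default_inputs:0", inputTensor)
                         .fetch("StatefulPartitionedCall:0")
                         .run().get(0)) {
                // Output shape is (-1, 1): one prediction per input row
                float[][] result = new float[input.length][1];
                resultTensor.copyTo(result);
                System.out.println("result: \n" + JSON.toJSONString(result));
            }
        }
    }
}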
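The example builds the input tensor from a nested float[][] via Tensor.create(Object). For larger batches, TF Java 1.x also provides Tensor.create(long[] shape, FloatBuffer data), which copies elements starting at the buffer's current position. The sketch below is a hypothetical helper (BatchPacker, not a TensorFlow class) that checks each row against the signature's width of 2 and flattens the batch row-major into a FloatBuffer plus a shape array; handing p.shape and p.data to Tensor.create is assumed to work with your TF Java version and is not shown here.

```java
import java.nio.FloatBuffer;

// Hypothetical helper: packs a batch for a (-1, 2) float input as shape + flat buffer
public class BatchPacker {
    static final int FEATURES = 2; // from inputs['inputs'] shape (-1, 2)

    final long[] shape;
    final FloatBuffer data;

    private BatchPacker(long[] shape, FloatBuffer data) {
        this.shape = shape;
        this.data = data;
    }

    // Validates row widths and copies rows into one row-major buffer
    static BatchPacker pack(float[][] rows) {
        FloatBuffer buf = FloatBuffer.allocate(rows.length * FEATURES);
        for (int i = 0; i < rows.length; i++) {
            if (rows[i].length != FEATURES) {
                throw new IllegalArgumentException("row " + i + " has " + rows[i].length
                        + " features, expected " + FEATURES);
            }
            buf.put(rows[i]);
        }
        buf.flip(); // reset position to 0 so readers start at the first element
        return new BatchPacker(new long[] {rows.length, FEATURES}, buf);
    }

    public static void main(String[] args) {
        BatchPacker p = pack(new float[][] {{2.6327686f, -9.201903f}, {-1.3209248f, 8.569574f}});
        System.out.println("shape = [" + p.shape[0] + ", " + p.shape[1] + "], floats = "
                + p.data.remaining());
    }
}
```

The flip() call matters: without it the buffer's position would sit at the end after the writes, and a reader that starts at the current position would see no data.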