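When converting a TensorFlow .pb model, the converter needs the names of the input and output tensors. You can find them with tools that ship with TensorFlow.
Viewing input and output nodes with TensorFlow's built-in tool
1. From the root of the tensorflow source tree, build summarize_graph: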
bazel build tensorflow/tools/graph_transforms:summarize_graph
2. Run it, pointing --in_graph at your .pb model:
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=tensorflow_inception_graph.pb
Sample output (this run was on an EDSR_x4.pb model):
Found 1 possible inputs: (name=IteratorGetNext, type=float(1), shape=[1,224,224,3])
No variables spotted.
Found 1 possible outputs: (name=NCHW_output, op=Transpose)
Found 38473178 (38.47M) const parameters, 0 (0) variable parameters, and 0 control_edges
Op types used: 301 Const, 100 Add, 67 Conv2D, 67 Dequantize, 32 Mul, 32 Relu, 1 DepthToSpace, 1 Placeholder, 1 Transpose
To use with tensorflow/tools/benchmark:benchmark_model try these arguments:
bazel run tensorflow/tools/benchmark:benchmark_model -- --graph=/Users/chensi/util/facebeauty/superscale/model/EDSR_x4.pb --show_flops --input_layer=IteratorGetNext --input_layer_type=float --input_layer_shape= --output_layer=NCHW_output
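The "Found ... possible inputs/outputs" lines contain exactly what from_frozen_graph needs: input names, input shapes and output names. Below is a minimal sketch that turns that text into the three converter arguments, assuming the line format shown in the sample output. `converter_args_from_summary` and `parse_dims` are hypothetical helper names, and treating `?`/`-1` as an unknown dimension and pinning an unknown batch dimension to 1 are assumptions, not behavior of the tool.

```python
import re

# Matches entries like "(name=IteratorGetNext, type=float(1), shape=[1,224,224,3])"
INPUT_RE = re.compile(r"\(name=([^,]+), type=[^,]+, shape=\[([^\]]*)\]\)")
# Matches entries like "(name=NCHW_output, op=Transpose)"
OUTPUT_RE = re.compile(r"\(name=([^,]+), op=[^)]+\)")


def parse_dims(text):
    """Turn '1,224,224,3' into [1, 224, 224, 3]; '?' or '-1' become None (unknown)."""
    dims = []
    for part in text.split(","):
        part = part.strip()
        if not part:
            continue
        dims.append(None if part in ("?", "-1") else int(part))
    return dims


def converter_args_from_summary(summary, default_batch=1):
    """Build (input_arrays, input_shapes, output_arrays) from summarize_graph text."""
    input_arrays, input_shapes, output_arrays = [], {}, []
    for line in summary.splitlines():
        if "possible inputs" in line:
            for name, dims in INPUT_RE.findall(line):
                shape = parse_dims(dims)
                # Assumption: an unknown leading (batch) dimension can be pinned.
                if shape and shape[0] is None:
                    shape[0] = default_batch
                input_arrays.append(name)
                input_shapes[name] = shape
        elif "possible outputs" in line:
            output_arrays.extend(OUTPUT_RE.findall(line))
    return input_arrays, input_shapes, output_arrays
```

The three returned values map onto the `input_arrays`, `input_shapes` and `output_arrays` arguments of `tf.compat.v1.lite.TFLiteConverter.from_frozen_graph`. Check the result by eye: names containing commas or other output formats are not handled.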
Printing every node in the graph (name and op type):
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Load the frozen graph and list every node as "name ===> op"
gf = tf.GraphDef()
with open('/Users/chensi/util/facebeauty/superscale/model/EDSR_x4.pb', 'rb') as f:
    gf.ParseFromString(f.read())
for n in gf.node:
    print(n.name + ' ===> ' + n.op)
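When the node list is long, a rough heuristic narrows it down: Placeholder nodes are candidate inputs, and nodes that no other node consumes are candidate outputs. The sketch below is a simplified illustration of that idea, not summarize_graph's actual code. `guess_io` is a hypothetical helper, and it takes plain `(name, op, inputs)` tuples that mirror the `name`, `op` and `input` fields of each `NodeDef`, so it can be tried without TensorFlow.

```python
def guess_io(nodes):
    """nodes: iterable of (name, op, inputs) tuples mirroring NodeDef fields.

    Returns (placeholders, sinks): Placeholder nodes are candidate inputs;
    nodes that no other node consumes are candidate outputs.
    """
    nodes = list(nodes)
    consumed = set()
    for _, _, inputs in nodes:
        for inp in inputs:
            # Inputs look like "name", "name:1" (output index) or "^name" (control dep).
            consumed.add(inp.lstrip("^").split(":")[0])
    placeholders = [name for name, op, _ in nodes if op == "Placeholder"]
    # Unconsumed Const/NoOp/Assert nodes are leftovers, not model outputs.
    ignored = {"Const", "NoOp", "Assert", "Placeholder"}
    sinks = [name for name, op, _ in nodes
             if name not in consumed and op not in ignored]
    return placeholders, sinks
```

With the graph loaded as above, it could be called as `guess_io((n.name, n.op, list(n.input)) for n in gf.node)`. Treat the result as a shortlist to confirm by hand.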
Converting TensorFlow models to TFLite
GraphDef (*.pb) to TFLite
import tensorflow as tf

in_path = "Shinkai_53.pb"
input_arrays = ["generator_input"]
input_shapes = {"generator_input": [1, 256, 256, 3]}
output_arrays = ["generator/G_MODEL/out_layer/Tanh"]  # for several outputs, list them all: ["out1", "out2"]

# The TF 1.x and 2.x converter APIs are incompatible: in 2.x, from_frozen_graph
# is only available through tf.compat.v1
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    in_path, input_arrays, output_arrays, input_shapes)
converter.allow_custom_ops = True  # needed by many models, or conversion fails; custom ops still need a runtime implementation
# converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enable post-training quantization
tflite_model = converter.convert()
with open("Shinkai_53.tflite", "wb") as f:
    f.write(tflite_model)
SavedModel to TFLite
import tensorflow as tf

TF_PATH = "tf_model"  # SavedModel directory
TFLITE_PATH = "../../animegan/model/face_paint_512_v3.tflite"
converter = tf.lite.TFLiteConverter.from_saved_model(TF_PATH)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # post-training quantization
tf_lite_model = converter.convert()
with open(TFLITE_PATH, 'wb') as f:
    f.write(tf_lite_model)
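Setting `optimizations = [tf.lite.Optimize.DEFAULT]` without a representative dataset gives dynamic range quantization: weights are stored as 8-bit integers, which shrinks the file roughly 4x. The sketch below shows the idea with a single per-tensor symmetric scale (zero point 0, range [-127, 127]). It is a simplification: the real converter can use per-channel scales for some ops, and `quantize_symmetric_int8` is a hypothetical helper, not a TFLite API.

```python
def quantize_symmetric_int8(weights):
    """Map float weights to int8 using one scale and zero_point = 0."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float weights from int8 values."""
    return [v * scale for v in q]


weights = [0.42, -1.3, 0.07, 0.9]
q, scale = quantize_symmetric_int8(weights)
restored = dequantize(q, scale)
print(q, scale)
```

Each restored value lies within half a quantization step (`scale / 2`) of the original, which is why accuracy usually drops only slightly.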
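Keras model to TFLite
import os
import tensorflow as tf
from tensorflow import keras

loaded_keras_model = keras.models.load_model('./keras_model.h5')
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_tflite = keras_to_tflite_converter.convert()
os.makedirs('./tflite_models', exist_ok=True)  # open() does not create missing directories
with open('./tflite_models/keras_tflite.tflite', 'wb') as f:
    f.write(keras_tflite)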
ONNX model to TFLite
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # set before TensorFlow is imported (onnx_tf imports it)

import onnx
import tensorflow as tf
from onnx_tf.backend import prepare

# Step 1: convert the ONNX model to a SavedModel
TF_PATH = "tf_model"  # where the TensorFlow SavedModel will be written
ONNX_PATH = "photo2cartoon_weights.onnx"  # path to the existing ONNX model
onnx_model = onnx.load(ONNX_PATH)  # load the ONNX model
tf_rep = prepare(onnx_model)  # create a TensorflowRep object
tf_rep.export_graph(TF_PATH)

# Step 2: convert the SavedModel to TFLite
TFLITE_PATH = "./photo2cartoon_weights.tflite"
converter = tf.lite.TFLiteConverter.from_saved_model(TF_PATH)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()
with open(TFLITE_PATH, 'wb') as f:
    f.write(tf_lite_model)
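ONNX models (for example ones exported from PyTorch) usually expect NCHW input, while images loaded in Python are usually laid out as HWC. Assumption: the model produced through onnx-tf keeps the ONNX model's NCHW input layout; confirm it with `tf.lite.Interpreter(model_path=...).get_input_details()` before feeding data. If so, the input buffer has to be reordered. A minimal pure-Python sketch of that reordering on a flat, row-major buffer (`nhwc_to_nchw` is a hypothetical helper; in practice a NumPy transpose does the same):

```python
def nhwc_to_nchw(flat, n, h, w, c):
    """Reorder a flat row-major NHWC buffer into NCHW order."""
    out = [0] * (n * c * h * w)
    for ni in range(n):
        for hi in range(h):
            for wi in range(w):
                for ci in range(c):
                    src = ((ni * h + hi) * w + wi) * c + ci  # NHWC index
                    dst = ((ni * c + ci) * h + hi) * w + wi  # NCHW index
                    out[dst] = flat[src]
    return out


# Two RGB pixels: [r0, g0, b0, r1, g1, b1] becomes [r0, r1, g0, g1, b0, b1]
print(nhwc_to_nchw([1, 2, 3, 4, 5, 6], n=1, h=1, w=2, c=3))
```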
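Problems encountered
Some .pb models fail during conversion, or the resulting .tflite model cannot be loaded.
The scripts above may break on some 1.x models: the 1.x and 2.x APIs are incompatible, some ops are not supported by TFLite, models may depend on placeholders, and so on. Converting a .pb model to TFLite is not guaranteed to succeed; if possible, get the model as a SavedModel, which is the primary format in 2.x.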
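When a converted file "cannot be recognized", a cheap first check is whether it is a TFLite flatbuffer at all. The TFLite schema declares the file identifier `TFL3`, which FlatBuffers stores at byte offset 4, right after the 4-byte root offset. A minimal sketch (`looks_like_tflite` is a hypothetical helper); passing this check does not mean the runtime supports every op, since custom ops let through by `allow_custom_ops` still need an implementation.

```python
def looks_like_tflite(data: bytes) -> bool:
    """True if the bytes carry the TFLite file identifier "TFL3" at offset 4."""
    return len(data) >= 8 and data[4:8] == b"TFL3"


def check_file(path: str) -> bool:
    """Read only the first 8 bytes of a file and run the identifier check."""
    with open(path, "rb") as f:
        return looks_like_tflite(f.read(8))
```

For example, `check_file("Shinkai_53.tflite")` should return True for a file written by the converter.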