pytorch转tflite实践

这个世界总是有各种各样的幺蛾子,所以我们要做各种各样的转换,就像今天要写的pytorch模型需要被转换成tflite。下面就以pytorch-ssd模型为例,做一次pytorch转tflite的实践。

  • pth模型转换成onnx
    第一步把torch.save()存下的模型转换成onnx模型,代码如下
import torch
from vision.ssd.mobilenet_v3_ssd_lite import create_mobilenetv3_ssd_lite

model = create_mobilenetv3_ssd_lite(num_classes=2)
torch.load("CARN_model_checkpoint.pt",map_location='cpu')['state_dict'].items()},False)
model.load_state_dict(torch.load("Epoch-85-Loss-0.4889--Epoch-45-Loss-0.4090.pth",map_location='cpu'))
dummy_input = torch.randn(1,3,300,300)
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(model, dummy_input, "ssd_Epoch-45.onnx", verbose=True, input_names=input_names, output_names=output_names,opset_version=11)
  • onnx转换成tensorflow pb模型
    第二步把onnx模型转换成tensorflow pb模型
git clone https://github.com/onnx/onnx-tensorflow.git
cd onnx-tensorflow
git checkout v1.6.0-tf-1.15
pip install -e .
onnx-tf convert -i /path/to/input.onnx -o /path/to/output.pb

通过第二步操作就生成了pb模型。

  • 把nchw格式pb模型转换成nhwc格式pb模型
    因为pth和onnx模型都是nchw的layout,转换成pb之后layout没有变,而tflite和tensorflow模型是nhwc的layout格式的,所以需要再增加一步转换,把nchw格式pb模型转换成nhwc格式pb模型,其实原理就是增加tranpose算子,代码如下:
import tensorflow as tf
if not tf.__version__.startswith('1'):
  import tensorflow.compat.v1 as tf
from tensorflow.python.tools import optimize_for_inference_lib

graph_def_file = "..\output.pb"

tf.reset_default_graph()
graph_def = tf.GraphDef()
with tf.Session() as sess:
    # Read binary pb graph from file
    with tf.gfile.Open(graph_def_file, "rb") as f:
        data2read = f.read()
        graph_def.ParseFromString(data2read)
    tf.graph_util.import_graph_def(graph_def, name='')
    
    # Get Nodes
    conv_nodes = [n for n in sess.graph.get_operations() if n.type in ['Conv2D','MaxPool','AvgPool']]
    for n_org in conv_nodes:
        # Transpose input
        assert len(n_org.inputs)==1 or len(n_org.inputs)==2
        org_inp_tens = sess.graph.get_tensor_by_name(n_org.inputs[0].name)
        inp_tens = tf.transpose(org_inp_tens, [0, 2, 3, 1], name=n_org.name +'_transp_input')
        op_inputs = [inp_tens]
        
        # Get filters for Conv but don't transpose
        if n_org.type == 'Conv2D':
            filter_tens = sess.graph.get_tensor_by_name(n_org.inputs[1].name)
            op_inputs.append(filter_tens)
        
        # Attributes without data_format, NWHC is default
        atts = {key:n_org.node_def.attr[key] for key in list(n_org.node_def.attr.keys()) if key != 'data_format'}
        if n_org.type in['MaxPool', 'AvgPool','Conv2D']:
            st = atts['strides'].list.i
            stl = [st[0], st[2], st[3], st[1]]
            atts['strides'] = tf.AttrValue(list=tf.AttrValue.ListValue(i=stl))
        if n_org.type in ['MaxPool', 'AvgPool']:
            st = atts['ksize'].list.i
            stl = [st[0], st[2], st[3], st[1]]
            atts['ksize'] = tf.AttrValue(list=tf.AttrValue.ListValue(i=stl))

        # Create new Operation
        #print(n_org.type, n_org.name, list(n_org.inputs), n_org.node_def.attr['data_format'])
        op = sess.graph.create_op(op_type=n_org.type, inputs=op_inputs, name=n_org.name+'_new', dtypes=[tf.float32], attrs=atts) 
        out_tens = sess.graph.get_tensor_by_name(n_org.name+'_new'+':0')
        out_trans = tf.transpose(out_tens, [0, 3, 1, 2], name=n_org.name +'_transp_out')
        assert out_trans.shape == sess.graph.get_tensor_by_name(n_org.name+':0').shape
        
        # Update Connections
        out_nodes = [n for n in sess.graph.get_operations() if n_org.outputs[0] in n.inputs]
        for out in out_nodes:
            for j, nam in enumerate(out.inputs):
                if n_org.outputs[0] == nam:
                    out._update_input(j, out_trans)
        
    # Delete old nodes
    graph_def = sess.graph.as_graph_def()
    for on in conv_nodes:
        graph_def.node.remove(on.node_def)

    # Write graph
    tf.io.write_graph(graph_def, "", graph_def_file.rsplit('.', 1)[0]+'_toco.pb', as_text=False)

第三步后会生成output_toco.pb模型,即为nhwc格式的pb模型。

  • 把nhwc格式的pb模型转换成tflite模型
    通过tensorflow转换工具把nhwc格式的pb模型转换成tflite模型
tflite_convert.exe --graph_def_file=output_toco.pb --output_file=ssd.tflite --input_arrays=input --output_arrays=output,1099

此时就生成了ssd.tflite模型,之后可用tflite进行前向推理。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,456评论 5 477
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,370评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,337评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,583评论 1 273
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,596评论 5 365
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,572评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,936评论 3 395
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,595评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,850评论 1 297
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,601评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,685评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,371评论 4 318
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,951评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,934评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,167评论 1 259
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 43,636评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,411评论 2 342