模型转换

pytorch转onnx

import torch.utils.datafrom torch.autograd import Variablefrom squeezenet import squeezenet1_2import onnxruntimeimport numpy as npfrom onnxruntime.datasets import get_exampleimport cv2import  onnxmodel_file = '/home/bbt/qinghua/detetion/pytorch-mobilenet/model/model_best_squeezenet1_2_age0.926.pth.tar'num_class =3# create modelmodel=squeezenet1_2(pretrained=False,num_classes=num_class)model = torch.nn.DataParallel(model)# optionally resume from a checkpointcheckpoint = torch.load(model_file)model.load_state_dict(checkpoint['state_dict'])model.to('cpu')model.eval()# model.cpu()#accuracy(mode)# dummy_input = Variable(torch.randn(1, 3, 224, 224))input=cv2.imread('/home/bbt/age00.jpg')input=cv2.resize(input,(224,224))input=np.transpose(input, (2, 0, 1)).astype(np.float32)now_image1 = Variable(torch.from_numpy(input))dummy_input = now_image1.unsqueeze(0)input_names=['input']output_names=['output']torch_out = torch.onnx._export(model.module, dummy_input, "/home/bbt/qinghua/detetion/pytorch-mobilenet/model/age.onnx",                               verbose=True, input_names=input_names, output_names=output_names)#test onnx modelexample_model = get_example('/home/bbt/qinghua/detetion/pytorch-mobilenet/model/age.onnx')sess = onnxruntime.InferenceSession(example_model)result = sess.run([output_name], {input_name: dummy_input.data.numpy()})np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), result[0], decimal=3)print(result[0])

转换后的模型保存在age.onnx,还需要对age.onnx模型进行简化,

python -m onnxsim age.onnx age-sim.onnx

得到age-sim.onnx文件,后续转ncnn模型使用这个。

###onnx模型转ncnn

编译ncnn:

1.编译本地ncnn

git clone https://github.com/Tencent/ncnn.git

cd ncnn

mkdir -p build

cd build

cmake …

make -j4

编译后,在目录ncnn/build/tools/caffe下,分别有ncnn2mem和caffe2ncnn两个可执行文件:

  • caffe2ncnn 将caffemodel转换为ncnnmodel
  • ncnn2mem 对模型进行加密操作

模型转换:

./onnx2ncnn age-sim.onnx age.param age.bin

模型加密:

./ncnn2mem age.param age.bin age.id.h age.men.h

会生成aeg.param.bin, age.id.h, age.mem.h。

android移植

新建android-ncnn工程,可以参考https://blog.csdn.net/qq_33431368/article/details/85009758。

或者直接下载编译好的工程,https://github.com/chehongshu/ncnnforandroid_objectiondetection_Mobilenetssd/tree/master/MobileNetSSD_demo_single,我们以这个工程为例,直接修改为自己的模型。

首先将自己的模型文件age.param.bin,age.bin,标签文件label.txt(每行对应标签名) 拷贝到ncnnforandroid_objectiondetection_Mobilenetssd/MobileNetSSD_demo_single/app/src/main/asset/.

将age.id.h文件拷贝到ncnnforandroid_objectiondetection_Mobilenetssd/MobileNetSSD_demo_single/app/src/main/cpp/

修改ncnnforandroid_objectiondetection_Mobilenetssd/MobileNetSSD_demo_single/app/src/main/cpp/MobileNetssd.cpp

文件:

修改include “MobileNetSSD_deploy.id.h” 为include “age.id.h”

####a.输入:

由于我的输入图片是直接cv2.imread(‘tupian.jpg’),读取的为bgr格式,因此修改输入为,

in = ncnn::Mat::from_pixels((const unsigned char*)indata, ncnn::Mat::PIXEL_RGBA2BGR, width, height);

由于我没有归一化,注释掉一下行,

const float mean_vals[3] = {127.5f, 127.5f, 127.5f}; const float scale[3] = {0.007843f, 0.007843f, 0.007843f}; in.substract_mean_normalize(mean_vals, scale);// 归一化

####b.模型输入、输出名修改

按照age.id.h的输入,输出名,修改输入输出,

// 如果不加密是使用ex.input(“data”, in);

// BLOB_data在id.h文件中可见,相当于datainput网络层的id

ex.input(age_param_id::BLOB_input, in);

// 如果时不加密是使用ex.extract(“prob”, out);

//BLOB_detection_out.h文件中可见,相当于dataout网络层的id,输出检测的结果数据

ex.extract(age_param_id::BLOB_output, out);

到此,模型可以正常预测了。

参考:

输入问题:https://github.com/Tencent/ncnn/wiki/FAQ-ncnn-produce-wrong-result

参考:https://blog.csdn.net/qq_33431368/article/details/85009758

更多相关文章

  1. Android语言和文化适配
  2. Android设置输入框和软键盘动态悬浮
  3. The following classes could not be instantiated: - android.s
  4. Android(安卓)使用Room 生成不了数据库文件
  5. android之popupwindow显示文件列表
  6. android获取文件大小常用类
  7. Android中Sax解析与Dom解析xml文件
  8. hciconfig - HCI device configuration utility
  9. [置顶] 基于android2.3.5系统:Android动态库链接

随机推荐

  1. 项目实战7—Mysql实现企业级数据库主从复
  2. mysql值以数组格式转换为PHP数组
  3. MySQL多个连接到付款数据的日历表
  4. mysql5.6配置同步复制的新方法以及常见问
  5. 在mysql shell中显示没有表行的查询结果(
  6. SQL调优案例,MYSQL服务器CPU100%问题解决
  7. mysql中select列表可以有group列表中没有
  8. 确定SQL UPDATE是否更改了列的值
  9. 当python遇到mysql时,如何顺利安装mysql
  10. 如果没有匹配,则使用默认值执行左连接。