近期需要将pytorch模型运行到android手机上实验,在查阅网上博客后,发现大多数流程需要借助多个框架或软件,横跨多个编程语言、IDE。本文参考以下两篇博文,力求用更简洁的流程实现模型部署。
https://blog.csdn.net/xiaodidididi521/article/details/123985612
https://blog.csdn.net/m0_67391683/article/details/125401357
一、pytorch模型转化
pytorch模型无法直接被Android调用,需要转化为特定格式.pt。本文使用pycharm IDE完成这一步,工程目录结构如下:
![pycharm目录结构](https://img–blog.csdnimg.cn/d67266301c3f43bfa20d3585dc5fe836.png#pic_center
其中,vgg16bn_CIFAR10.pth和另一个pth文件是需要部署到手机上的模型,models.py是自己定义的网络结构。在此默认读者熟悉pytorch,对models.py不做赘述。
import torch.utils.data.distributed
'定义转化后的模型名称'
model_ori_pt ='model_ori.pt'
model_pruned_pt ='model_pruned.pt'
'加载pytorch模型'
model_ori = torch.load('vgg16bn_CIFAR10.pth')
model_pruned = torch.load('vgg16bn_CIFAR10_pruned.pth')
'模型在cpu上运行'
device = torch.device('cpu')
model_ori.to(device)
model_pruned.to(device)
model_ori.eval()
model_pruned.eval()
'定义输入图片的大小'
input_tensor = torch.rand(1, 3, 32, 32)
'转化模型并存储'
mobile_ori = torch.jit.trace(model_ori, input_tensor)
model_pruned = torch.jit.trace(model_pruned, input_tensor)
mobile_ori.save(model_ori_pt)
model_pruned.save(model_pruned_pt)
请注意,让模型在cpu上,或cuda上执行eval()均可,但要保证模型与input_tensor在同一设备上,否则将运行出错。运行后,会得到model_ori.pt与model_pruned.pt两个文件,即可以用于android上的文件。此时目录结构如下:
二、新建Android Studio工程
首先,需要在本地安装Android Studio,安装流程建议参照:
https://m.runoob.com/android/android–studio–install.html?ivk_sa=1024320u
然后打开Android Studio新建Empy Activity
点击Finsh。SDK建议选择7.0以往的安卓版本。**首次新建工程底部会长时间出现加载进度条,请耐心等待加载完成。**接下来,我们需要有一部手机调试工程,本文使用Android Studio自带的模拟器。首先点击顶部工具栏的Device Manager。
点击create device
接下来选择机型、安卓版本、内存等,如不想麻烦可一直点击next。
finsh后,Android Studio需要下载安卓版本包,需要耐心等待。下载完成后即可启动虚拟机。
三、转化后的模型部署安卓
首先,新建assets文件夹,请不要直接新建,需右键app->Folder->Assets Folder。
之后将转化好的两个模型及侧视图放入assets文件夹。本文使用的是CIFAR10数据集,可在以下网址下载:
http://www.cs.toronto.edu/~kriz/cifar.html
然后在gradle Scripts 文件夹中的build.gradle(Module :app)文件中的depencies里添加:
implementation 'org.pytorch:pytorch_android:1.12.1'
implementation 'org.pytorch:pytorch_android_torchvision:1.12.1'
请注意**1.12.1是本文使用的pytorch版本,读者应该为对应的版本号。**然后点击工具栏下的sync now,再耐心等待运行按钮变绿。
双击res->layout->activity_main.xml并切换到code。
删除所有代码,复制以下代码段:
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:scaleType="fitCenter" />
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="top"
android:textSize="24sp"
android:textColor="@android:color/holo_red_light" />
</FrameLayout>
然后右键java里的com.example.工程名 文件夹,New->Java Class。本文新建的类名是CIfarClassed,类内代码:
package com.example.工程名;
public class CifarClassed {
public static String[] IMAGENET_CLASSES = new String[]{
"ddd",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
};
}
最后打开java->com.example.工程名->MainActivity,删除原代码,用以下代码替代:
package com.example.dnna;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
import com.example.dnna.CifarClassed;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module_ori = null;
Module module_pruned = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("x.png"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module_ori = Module.load(assetFilePath(this, "model_ori.pt"));
module_pruned = Module.load(assetFilePath(this, "model——pruned.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// running the model
long startTime_ori = System.currentTimeMillis();
final Tensor outputTensor_ori = module_ori.forward(IValue.from(inputTensor)).toTensor();
long endTime_ori = System.currentTimeMillis();
long InferenceTimeOri=endTime_ori - startTime_ori;
long startTime_pruned = System.currentTimeMillis();
final Tensor outputTensor_pruned = module_pruned.forward(IValue.from(inputTensor)).toTensor();
long endTime_pruned = System.currentTimeMillis();
long InferenceTimePruned=endTime_pruned - startTime_pruned;
// getting tensor content as java array of floats
final float[] scores = outputTensor_ori.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
System.out.println(maxScoreIdx);
String className = CifarClassed.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
String tex="推理结果:"+className+"n原始模型推理时间:"+InferenceTimeOri+"ms"+"n剪枝模型推理时间:"+InferenceTimePruned+"ms";
textView.setText(tex);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
四、结语
本文的主要流程是:
本人目前希望提升自己的博客撰写水平,如读者在实现过程中遇到困难,或在阅读本文时感到困惑,欢迎留言或添加我的QQ:1106295085。我将在周日下午回复,并积极修改本文。
原文地址:https://blog.csdn.net/qq_39068200/article/details/129231207
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.7code.cn/show_27832.html
如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱:suwngjj01@126.com进行投诉反馈,一经查实,立即删除!