【干货】TensorFlow Lite Android部署介绍

2019-02-20 18:12发布

TensorFlow Lite android部署介绍

1、简要介绍

TensorFlow Lite是TensorFlow在移动和嵌入式设备上的轻量级解决方案,目前只能用于预测,还不能进行训练。TensorFLow Lite针对移动和嵌入设备开发,具有如下三个特点:

轻量

跨平台

快速

目前TensorFlow Lite已经支持android、iOS、Raspberry等设备,本章会基于Android设备上的部署方法进行讲解,内容包括模型保存、转换和部署。

2、模型保存

我们以keras模型训练和保存为例进行讲解,如下是keras官方的mnist模型训练样例。

'''Trains a simple convnet on the MNIST dataset.

Gets to 99.25% test accuracy after 12 epochs

(there is still a lot of margin for parameter tuning).

16 seconds per epoch on a GRID K520 GPU.

'''

from__future__importprint_function

importkeras

fromkeras.datasetsimportmnist

fromkeras.modelsimportSequential

fromkeras.layersimportDense,Dropout,Flatten

fromkeras.layersimportConv2D,MaxPooling2D

fromkerasimportbackendasK

batch_size=128

num_classes=10

epochs=12

# input image dimensions

img_rows,img_cols=28,28

# the data, split between train and test sets

(x_train,y_train), (x_test,y_test) =mnist.load_data()

ifK.image_data_format() =='channels_first':

x_train=x_train.reshape(x_train.shape[],1,img_rows,img_cols)

x_test=x_test.reshape(x_test.shape[],1,img_rows,img_cols)

input_shape= (1,img_rows,img_cols)

else:

x_train=x_train.reshape(x_train.shape[],img_rows,img_cols,1)

x_test=x_test.reshape(x_test.shape[],img_rows,img_cols,1)

input_shape= (img_rows,img_cols,1)

x_train=x_train.astype('float32')

x_test=x_test.astype('float32')

x_train/=255

x_test/=255

print('x_train shape:',x_train.shape)

print(x_train.shape[],'train samples')

print(x_test.shape[],'test samples')

# convert class vectors to binary class matrices

y_train=keras.utils.to_categorical(y_train,num_classes)

y_test=keras.utils.to_categorical(y_test,num_classes)

model=Sequential()

model.add(Conv2D(32,kernel_size=(3,3),

activation='relu',

input_shape=input_shape))

model.add(Conv2D(64, (3,3),activation='relu'))

model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Dropout(0.25))

model.add(Flatten())

model.add(Dense(128,activation='relu'))

model.add(Dropout(0.5))

model.add(Dense(num_classes,activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,

optimizer=keras.optimizers.Adadelta(),

metrics=['accuracy'])

model.fit(x_train,y_train,

batch_size=batch_size,

epochs=epochs,

verbose=1,

validation_data=(x_test,y_test))

score=model.evaluate(x_test,y_test,verbose=)

print('Test loss:',score[])

print('Test accuracy:',score[1])

创建mnist_cnn.py文件,将以上内容拷贝进去,并在最后加上如下一行代码:

model.save('mnist_cnn.h5')

在终端中执行mnist_cnn.py文件,如下:

pythonmnist_cnn.py

注:该过程需要连接网络获取mnist.npz文件(https://s3.amazonaws.com/img-datasets/mnist.npz),会被保存到$HOME/.keras/datasets/。如果网络连接存在问题,可以通过其他方式获取mnist.npz后,直接保存到该目录即可。

执行过程会比较久,执行结束后,会产生在当前目录产生文件(HDF5格式),就是keras训练后模型,其中已经包含了训练后的模型结构和权重等信息。

3、模型转换

不能直接在移动端部署,因为模型大小和运行效率比较低,最终需要通过工具转化为Flat Buffer格式的模型。

谷歌提供了多种转换方式:

tflight_convert:>= TensorFlow 1.9,本次讲这个

TOCO:>= TensorFlow 1.7

通过代码转换

tflight_convert跟tensorflow是一起下载的,笔者通过brew安装的python,pip安装tf-nightly后tflight_convert路径如下:

/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin

实际上,应该是/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/bin,但是软链接到了如上路径。如果命令行不能执行到tflight_convert,则在~/.bash_profile(macOS)或~/.bashrc(Linux)添加如下环境变量:

exportPATH="/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin:$PATH"

然后执行

source~/.bash_profile

source~/.bashrc

在命令执行

tflight_convert-h

输出结果如下,则证明安装配置成功。

usage: tflite_convert [-h]--output_fileOUTPUT_FILE

(--graph_def_file GRAPH_DEF_FILE |--saved_model_dirSAVED_MODEL_DIR |--keras_model_fileKERAS_MODEL_FILE)

[--output_format ]

[--inference_type ]

[--inference_input_type ]

[--input_arrays INPUT_ARRAYS]

下面我们开始转换模型,具体命令如下:

tflite_convert--keras_model_file=./mnist_cnn.h5--output_file=./mnist_cnn.tflite

到此,我们已经得到一个可以运行的TensorFlow Lite模型了,即。

注:这里只介绍了keras HDF5格式模型的转换,其他模型转换建议参考:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md

4、Android部署

现在开始在Android环境部署,对于国内的读者,需要先给Android Studio配置proxy,因为gradle编译环境需要获取相应的资源,请大家自行解决,这里不再赘述。

4.1 配置app/build.gradle

新建一个Android Project,打开app/build.gradle添加如下信息

android{

aaptOptions{

noCompress"tflite"

}

}

repositories{

maven{

url'https://google.bintray.com/tensorflow'

}

}

dependencies{

implementation'org.tensorflow:tensorflow-lite:1.10.0'

}

其中,

1、aaptOptions设置tflite文件不压缩,确保后面tflite文件可以被Interpreter正确加载。

2、org.tensorflow:tensorflow-lite的最新版本号,可以在这里查询https://bintray.com/google/tensorflow/tensorflow-lite,目前最新的是1.10.0版本。

设置好后,sync和build整个工程,如果build成功说明,配置成功。

4.2 添加tflite文件到assets文件夹

在app目录先新建assets目录,并将文件保存到assets目录。重新编译apk,检查新编译出来的apk的assets文件夹是否有文件。

使用apk analyzer查看新编译出来的apk,存在如下目录即编译打包成功。

assets

|__mnist_cnn.tflite

4.3 加载模型

使用如下函数将文件加载到memory-map中,作为Interpreter实例化的输入。

privatestaticfinalStringMODEL_PATH="mnist_cnn.tflite";

/** Memory-map the model file in Assets. */

privateMappedByteBufferloadModelFile(Activityactivity)throwsIOException{

AssetFileDescriptorfileDescriptor=activity.getAssets().openFd(MODEL_PATH);

FileInputStreaminputStream=newFileInputStream(fileDescriptor.getFileDescriptor());

FileChannelfileChannel=inputStream.getChannel();

longstartOffset=fileDescriptor.getStartOffset();

longdeclaredLength=fileDescriptor.getDeclaredLength();

returnf ileChannel.map(FileChannel.MapMode.READ_ONLY,startOffset,declaredLength);

}

实例化Interpreter,其中this为当前acitivity

tflite=newInterpreter(loadModelFile(this));

4.4 运行输入

我们使用mnist test测试集中的某张图片作为输入,mnist图像大小28*28,单像素。这样我们输入的数据需要设置成如下格式。

/** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */

privateByteBufferimgData=null;

privatestaticfinalintDIM_BATCH_SIZE=1;

privatestaticfinalintDIM_PIXEL_SIZE=1;

privatestaticfinalintDIM_IMG_WIDTH=28;

privatestaticfinalintDIM_IMG_HEIGHT=28;

protectedvoidonCreate() {

imgData=ByteBuffer.allocatedirect(

4*DIM_BATCH_SIZE*DIM_IMG_WIDTH*DIM_IMG_HEIGHT*DIM_PIXEL_SIZE);