TensorFlow Lite android部署介绍
1、简要介绍
TensorFlow Lite是TensorFlow在移动和嵌入式设备上的轻量级解决方案,目前只能用于预测,还不能进行训练。TensorFLow Lite针对移动和嵌入设备开发,具有如下三个特点:
轻量
跨平台
快速
目前TensorFlow Lite已经支持android、iOS、Raspberry等设备,本章会基于Android设备上的部署方法进行讲解,内容包括模型保存、转换和部署。
2、模型保存
我们以keras模型训练和保存为例进行讲解,如下是keras官方的mnist模型训练样例。
'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12 epochs
(there is still a lot of margin for parameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''
from__future__importprint_function
importkeras
fromkeras.datasetsimportmnist
fromkeras.modelsimportSequential
fromkeras.layersimportDense,Dropout,Flatten
fromkeras.layersimportConv2D,MaxPooling2D
fromkerasimportbackendasK
batch_size=128
num_classes=10
epochs=12
# input image dimensions
img_rows,img_cols=28,28
# the data, split between train and test sets
(x_train,y_train), (x_test,y_test) =mnist.load_data()
ifK.image_data_format() =='channels_first':
x_train=x_train.reshape(x_train.shape[],1,img_rows,img_cols)
x_test=x_test.reshape(x_test.shape[],1,img_rows,img_cols)
input_shape= (1,img_rows,img_cols)
else:
x_train=x_train.reshape(x_train.shape[],img_rows,img_cols,1)
x_test=x_test.reshape(x_test.shape[],img_rows,img_cols,1)
input_shape= (img_rows,img_cols,1)
x_train=x_train.astype('float32')
x_test=x_test.astype('float32')
x_train/=255
x_test/=255
print('x_train shape:',x_train.shape)
print(x_train.shape[],'train samples')
print(x_test.shape[],'test samples')
# convert class vectors to binary class matrices
y_train=keras.utils.to_categorical(y_train,num_classes)
y_test=keras.utils.to_categorical(y_test,num_classes)
model=Sequential()
model.add(Conv2D(32,kernel_size=(3,3),
activation='relu',
input_shape=input_shape))
model.add(Conv2D(64, (3,3),activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128,activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes,activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adadelta(),
metrics=['accuracy'])
model.fit(x_train,y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test,y_test))
score=model.evaluate(x_test,y_test,verbose=)
print('Test loss:',score[])
print('Test accuracy:',score[1])
创建mnist_cnn.py文件,将以上内容拷贝进去,并在最后加上如下一行代码:
model.save('mnist_cnn.h5')
在终端中执行mnist_cnn.py文件,如下:
pythonmnist_cnn.py
注:该过程需要连接网络获取mnist.npz文件(https://s3.amazonaws.com/img-datasets/mnist.npz),会被保存到$HOME/.keras/datasets/。如果网络连接存在问题,可以通过其他方式获取mnist.npz后,直接保存到该目录即可。
执行过程会比较久,执行结束后,会产生在当前目录产生文件(HDF5格式),就是keras训练后模型,其中已经包含了训练后的模型结构和权重等信息。
3、模型转换
不能直接在移动端部署,因为模型大小和运行效率比较低,最终需要通过工具转化为Flat Buffer格式的模型。
谷歌提供了多种转换方式:
tflight_convert:>= TensorFlow 1.9,本次讲这个
TOCO:>= TensorFlow 1.7
通过代码转换
tflight_convert跟tensorflow是一起下载的,笔者通过brew安装的python,pip安装tf-nightly后tflight_convert路径如下:
/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin
实际上,应该是/usr/local/Cellar/python/3.6.5_1/Frameworks/Python.framework/Versions/3.6/bin,但是软链接到了如上路径。如果命令行不能执行到tflight_convert,则在~/.bash_profile(macOS)或~/.bashrc(Linux)添加如下环境变量:
exportPATH="/usr/local/opt/python/Frameworks/Python.framework/Versions/3.6/bin:$PATH"
然后执行
source~/.bash_profile
或
source~/.bashrc
在命令执行
tflight_convert-h
输出结果如下,则证明安装配置成功。
usage: tflite_convert [-h]--output_fileOUTPUT_FILE
(--graph_def_file GRAPH_DEF_FILE |--saved_model_dirSAVED_MODEL_DIR |--keras_model_fileKERAS_MODEL_FILE)
[--output_format ]
[--inference_type ]
[--inference_input_type ]
[--input_arrays INPUT_ARRAYS]
下面我们开始转换模型,具体命令如下:
tflite_convert--keras_model_file=./mnist_cnn.h5--output_file=./mnist_cnn.tflite
到此,我们已经得到一个可以运行的TensorFlow Lite模型了,即。
注:这里只介绍了keras HDF5格式模型的转换,其他模型转换建议参考:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/tflite_convert/cmdline_examples.md
4、Android部署
现在开始在Android环境部署,对于国内的读者,需要先给Android Studio配置proxy,因为gradle编译环境需要获取相应的资源,请大家自行解决,这里不再赘述。
4.1 配置app/build.gradle
新建一个Android Project,打开app/build.gradle添加如下信息
android{
aaptOptions{
noCompress"tflite"
}
}
repositories{
maven{
url'https://google.bintray.com/tensorflow'
}
}
dependencies{
implementation'org.tensorflow:tensorflow-lite:1.10.0'
}
其中,
1、aaptOptions设置tflite文件不压缩,确保后面tflite文件可以被Interpreter正确加载。
2、org.tensorflow:tensorflow-lite的最新版本号,可以在这里查询https://bintray.com/google/tensorflow/tensorflow-lite,目前最新的是1.10.0版本。
设置好后,sync和build整个工程,如果build成功说明,配置成功。
4.2 添加tflite文件到assets文件夹
在app目录先新建assets目录,并将文件保存到assets目录。重新编译apk,检查新编译出来的apk的assets文件夹是否有文件。
使用apk analyzer查看新编译出来的apk,存在如下目录即编译打包成功。
assets
|__mnist_cnn.tflite
4.3 加载模型
使用如下函数将文件加载到memory-map中,作为Interpreter实例化的输入。
privatestaticfinalStringMODEL_PATH="mnist_cnn.tflite";
/** Memory-map the model file in Assets. */
privateMappedByteBufferloadModelFile(Activityactivity)throwsIOException{
AssetFileDescriptorfileDescriptor=activity.getAssets().openFd(MODEL_PATH);
FileInputStreaminputStream=newFileInputStream(fileDescriptor.getFileDescriptor());
FileChannelfileChannel=inputStream.getChannel();
longstartOffset=fileDescriptor.getStartOffset();
longdeclaredLength=fileDescriptor.getDeclaredLength();
returnf ileChannel.map(FileChannel.MapMode.READ_ONLY,startOffset,declaredLength);
}
实例化Interpreter,其中this为当前acitivity
tflite=newInterpreter(loadModelFile(this));
4.4 运行输入
我们使用mnist test测试集中的某张图片作为输入,mnist图像大小28*28,单像素。这样我们输入的数据需要设置成如下格式。
/** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
privateByteBufferimgData=null;
privatestaticfinalintDIM_BATCH_SIZE=1;
privatestaticfinalintDIM_PIXEL_SIZE=1;
privatestaticfinalintDIM_IMG_WIDTH=28;
privatestaticfinalintDIM_IMG_HEIGHT=28;
protectedvoidonCreate() {
imgData=ByteBuffer.allocatedirect(
4*DIM_BATCH_SIZE*DIM_IMG_WIDTH*DIM_IMG_HEIGHT*DIM_PIXEL_SIZE);