• 正文
    • 01、常見(jiàn)的模型格式
    • 02、建立模型并進(jìn)行轉(zhuǎn)換
    • 03、運(yùn)行驗(yàn)證環(huán)節(jié)
  • 相關(guān)推薦
申請(qǐng)入駐 產(chǎn)業(yè)圖譜

研發(fā)干貨丨OK1052-C開發(fā)板運(yùn)行 Tensorflow_lite模型

2021/04/01
302
加入交流群
掃碼加入
獲取工程師必備禮包
參與熱點(diǎn)資訊討論

tensorFlowLite 是一款TensorFlow用于移動(dòng)設(shè)備和嵌入式設(shè)備的輕量級(jí)解決方案。

我們知道TensorFlow可以在多個(gè)平臺(tái)上運(yùn)行,從機(jī)架式服務(wù)器到小型IoT設(shè)備。但是隨著近年來(lái)機(jī)器學(xué)習(xí)模型的廣泛使用,出現(xiàn)了在移動(dòng)和嵌入式設(shè)備上部署它們的需求,而TensorFlowLite 允許設(shè)備端的機(jī)器學(xué)習(xí)模型的低延遲推斷。

飛凌OK1052-C開發(fā)板以高性能處理器i.MXRT1052作為硬件軀體,再用tensorflow_lite武裝軟件大腦,也搭上了開往機(jī)器學(xué)習(xí),邊緣計(jì)算等朝陽(yáng)行業(yè)領(lǐng)域的快車。

▲OK1052-C開發(fā)板接口圖

 

OK1052-C的用戶資料SDK中,已經(jīng)有了關(guān)于tensorflow_lite使用的demo例程,但是這些例程使用的都是現(xiàn)成訓(xùn)練之后的實(shí)驗(yàn)?zāi)P?,并不適用于我們實(shí)際的應(yīng)用場(chǎng)景。我們?cè)趯?shí)際應(yīng)用項(xiàng)目中,必然是需要使用適合本項(xiàng)目的模型,網(wǎng)上現(xiàn)成模型資源下載或者自己訓(xùn)練模型,然后搭載自己的應(yīng)用程序,完成我們的應(yīng)用項(xiàng)目需求。

由于模型的訓(xùn)練需要算力的支持,通常我們?cè)?a class="article-link" target="_blank" href="/tag/%E8%AE%A1%E7%AE%97%E6%9C%BA/">計(jì)算機(jī)上進(jìn)行模型的訓(xùn)練,我們將訓(xùn)練好的模型稱為預(yù)訓(xùn)練模型。我們也可以在網(wǎng)上download預(yù)訓(xùn)練模型,然后通過(guò)格式轉(zhuǎn)換轉(zhuǎn)換為tflite格式的模型,也就是開發(fā)板可執(zhí)行的模型格式;也可以自己搭建計(jì)算網(wǎng)絡(luò),通過(guò)訓(xùn)練之后,形成預(yù)訓(xùn)練模型,再轉(zhuǎn)換為tflite格式,運(yùn)行到開發(fā)板。

今天我們通過(guò)一個(gè)簡(jiǎn)單的例子,來(lái)介紹怎么建立計(jì)算網(wǎng)絡(luò)和進(jìn)行模型的轉(zhuǎn)換。

01、常見(jiàn)的模型格式

我們常見(jiàn)的訓(xùn)練模型格式有.PB,cpkt,saveModel,H5等,若想將模型運(yùn)用于tensor flowlite,需要轉(zhuǎn)換為tflite格式,一般的轉(zhuǎn)換過(guò)程是:

由上圖可知,如果要將Checkpoints(cpkt)格式轉(zhuǎn)換為tflite,需經(jīng)過(guò)freeze_graph.py工具將cpkt格式模型轉(zhuǎn)換為Frozen GraphDef(.pb)格式,然后再經(jīng)過(guò)TFLite Converter轉(zhuǎn)換工具轉(zhuǎn)換為tflite。

02、建立模型并進(jìn)行轉(zhuǎn)換

現(xiàn)在通過(guò)例子介紹如何將模型轉(zhuǎn)換為tflite格式,首先我們先建立兩個(gè)網(wǎng)絡(luò)流圖,并分別生成.pb和cpkt格式的模型。

1)建立計(jì)算網(wǎng)絡(luò),并保存為.PB格式模型

# coding=UTF-8

import tensorflow as tf

import shutil

import os.path

from tensorflow.python.framework import graph_util

output_graph = "easy_model/add_model.pb"

#下面的過(guò)程你可以替換成CNN、RNN等你想做的訓(xùn)練過(guò)程,這里只是簡(jiǎn)單的一個(gè)計(jì)算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否則返回false

predictions = tf.add(_y, 10,name="predictions") #做一個(gè)加法運(yùn)算

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

print ("predictions :", sess.run(predictions, feed_dict={input_holder: [10.0]}))

graph_def = tf.get_default_graph().as_graph_def() #得到當(dāng)前的圖的 GraphDef 部分,通過(guò)這個(gè)部分就可以完成重輸入層到輸出層的計(jì)算過(guò)程

output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定

sess,

graph_def,

["predictions"] #需要保存節(jié)點(diǎn)的名字

)

with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型

f.write(output_graph_def.SerializeToString()) # 序列化輸出

print("%d ops in the final graph." % len(output_graph_def.node))

print (predictions)

此計(jì)算網(wǎng)絡(luò)只是一個(gè)簡(jiǎn)單的數(shù)學(xué)運(yùn)算,不需要進(jìn)行訓(xùn)練,該運(yùn)算公式為:

predictions = (input_holder * W1) + B1 + 10

其中input_holder是網(wǎng)絡(luò)的輸入節(jié)點(diǎn),predictions是網(wǎng)絡(luò)的輸出節(jié)點(diǎn),W1,B1是兩個(gè)變量,分別被賦值為5.0,,1.0

程序運(yùn)行之后,會(huì)在easy_model/文件夾下生成add_model.pb模型文件。我們通過(guò)Netron軟件(Netron是一個(gè)很方便的軟件)來(lái)看一下add_model.pb模型文件中存儲(chǔ)的網(wǎng)絡(luò)圖及參數(shù)信息:

 

2)建立網(wǎng)絡(luò)圖,并保存為.ckpt格式模型:

# coding=UTF-8 支持中文編碼格式

import tensorflow as tf

import shutil

import os.path

MODEL_DIR = "easy_model"

MODEL_NAME = "model.ckpt"

if not tf.gfile.Exists(MODEL_DIR):

tf.gfile.MakeDirs(MODEL_DIR)

下面的過(guò)程你可以替換成CNN、RNN等你想做的訓(xùn)練過(guò)程,這里只是簡(jiǎn)單的一個(gè)計(jì)算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

#predictions = tf.greater(_y, 50, name="predictions")

predictions = tf.add(_y, 10,name="predictions")

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:

sess.run(init)

print ("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))

saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME))

print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node))

for op in tf.get_default_graph().get_operations():

print (op.name, op.values())

 

此網(wǎng)絡(luò)圖同上一節(jié)一樣是一個(gè)簡(jiǎn)單的數(shù)學(xué)運(yùn)算。

不同的是,此程序最后保存為ckpt格式的模型文件:

checkpoint :記錄目錄下所有模型文件列表

ckpt.data :保存模型中每個(gè)變量的取值

ckpt.meta :保存整個(gè)計(jì)算網(wǎng)絡(luò)圖的結(jié)構(gòu)

同樣可通過(guò)Netron軟件查看ckpt文件中存儲(chǔ)的網(wǎng)絡(luò)圖結(jié)構(gòu):

 

3)將生成的cpkt格式模型文件轉(zhuǎn)換為.pb文件:

import tensorflow as tf

from tensorflow.python.framework import graph_util

#from create_tf_record import *

resize_height = 100 # 指定圖片高度

resize_width = 100 # 指定圖片寬度

def freeze_graph(input_checkpoint, output_graph):

'''

:param input_checkpoint:

:param output_graph: PB 模型保存路徑

:return:

'''

# 檢查目錄下ckpt文件狀態(tài)是否可用

# checkpoint = tf.train.get_checkpoint_state(model_folder)

# 得ckpt文件路徑

# input_checkpoint = checkpoint.model_checkpoint_path

# 指定輸出的節(jié)點(diǎn)名稱,該節(jié)點(diǎn)名稱必須是元模型中存在的節(jié)點(diǎn)

output_node_names = "predictions"

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph() # 獲得默認(rèn)的圖

input_graph_def = graph.as_graph_def() # 返回一個(gè)序列化的圖代表當(dāng)前的圖

with tf.Session() as sess:

saver.restore(sess, input_checkpoint) # 恢復(fù)圖并得到數(shù)據(jù)

# 模型持久化,將變量值固定

output_graph_def = graph_util.convert_variables_to_constants(

sess=sess,

# 等于:sess.graph_def

input_graph_def=input_graph_def,

# 如果有多個(gè)輸出節(jié)點(diǎn),以逗號(hào)隔開

output_node_names=output_node_names.split(","))

# 保存模型

with tf.gfile.GFile(output_graph, "wb") as f:

f.write(output_graph_def.SerializeToString()) # 序列化輸出

# 得到當(dāng)前圖有幾個(gè)操作節(jié)點(diǎn)

print("%d ops in the final graph." % len(output_graph_def.node))

input_checkpoint='easy_model/model.ckpt'

# 輸出pb模型的路徑

out_pb_path="easy_model/frozen_model.pb"

freeze_graph(input_checkpoint,out_pb_path)

這里需要輸入輸出節(jié)點(diǎn)名稱output_node_names= "predictions",

程序運(yùn)行之后在easy_model文件夾下生成了frozen_model.pb文件,我們通過(guò)Netron軟件查看一下frozen_model.pb網(wǎng)絡(luò)圖,可以發(fā)現(xiàn)跟第一節(jié)中直接生成的.pb模型文件一樣的:

cpkt文件格式將模型保存為4個(gè)文件,pb文件格式為一個(gè)。ckpt模型持久化方式將圖結(jié)構(gòu)與權(quán)重參數(shù)分開保存,多了模型更多的細(xì)節(jié),適合模型訓(xùn)練階段;而pb持久化方式完成了從輸入到輸出的前向傳播,完成了端到端的形式,更適合離線使用。

 

4)最后,我們將.pb文件轉(zhuǎn)換為.tflite:

我們運(yùn)行此段代碼:

function showSnackbar() {

var $snackbar = $('#snackbar');

$snackbar.addClass('show');

setTimeout(() => {

$snackbar.removeClass('show');

}, 3000);

}

注意這里需要寫上輸入輸出節(jié)點(diǎn)名稱,這個(gè)在構(gòu)建網(wǎng)絡(luò)模型時(shí),已經(jīng)定義。

運(yùn)行之后,會(huì)報(bào)錯(cuò):

module 'tensorflow' has no attribute 'lite'

意思是目前的tensorflow版本不支持lite類,沒(méi)關(guān)系,我們重新安裝1.14版本的tensorflow即可成功運(yùn)行。最后生成tflite格式的模型文件:easy_model.lite。

03、運(yùn)行驗(yàn)證環(huán)節(jié)

將轉(zhuǎn)換后模型運(yùn)行到OK1052-C開發(fā)板進(jìn)行驗(yàn)證

首先需要將將easy_model.lite轉(zhuǎn)換為二進(jìn)制數(shù)組的.h文件easy_frozen_0.h,轉(zhuǎn)換完成之后,將其放入OK1052-C用戶資料的,SDK的middlewareeiqtensorflow-liteexampleslabel_image目錄下,

打開SDK中boardsevkbimxrt1050eiq_examplestensorflow_lite_label_image下工程,在文件label_iamge.cpp中做如下修改:


 

#include "board.h"

#include "pin_mux.h"

#include "clock_config.h"

#include "fsl_debug_console.h"

#include

#include

#include

#include "timer.h"

#include "tensorflow/lite/kernels/register.h"

#include "tensorflow/lite/model.h"

#include "tensorflow/lite/optional_debug_tools.h"

#include "tensorflow/lite/string_util.h"

//#include "Sine_mode.h"

//#include "add_model.h"

#include "easy_frozen_0.h"

int inference_count = 0;

// This is a small number so that it's easy to read the logs

const int kInferencesPerCycle = 30;

const float kXrange = 2.f * 3.14159265359f;

#define LOG(x) std::cout

void RunInference()

{

std::unique_ptr model;

std::unique_ptr interpreter;

model = tflite::FlatBufferModel::BuildFromBuffer(mobilenet_model, mobilenet_model_len);

if (!model) {

LOG(FATAL) << "Failed to load modelrn";

exit(-1);

}

model->error_reporter();

tflite::ops::builtin::BuiltinOpResolver resolver;

tflite::InterpreterBuilder(*model, resolver)(&interpreter);

if (!interpreter) {

LOG(FATAL) << "Failed to construct interpreterrn";

exit(-1);

}

float input = interpreter->inputs()[0];

if (interpreter->AllocateTensors() != kTfLiteOk) {

LOG(FATAL) << "Failed to allocate tensors!rn";

}

while(true)

{

// Calculate an x value to feed into the model. We compare the current

// inference_count to the number of inferences per cycle to determine

// our position within the range of possible x values the model was

// trained on, and use this to calculate a value.

float position = static_cast(inference_count) /

static_cast(kInferencesPerCycle);

float x_val = position * kXrange;

float* input_tensor_data = interpreter->typed_tensor(input);

*input_tensor_data = x_val;

// Delay_time(1000);

// Run inference, and report any error

TfLiteStatus invoke_status = interpreter->Invoke();

if (invoke_status != kTfLiteOk)

{

LOG(FATAL) << "Failed to invoke tflite!rn";

return;

}

// Read the predicted y value from the model's output tensor

float* y_val = interpreter->typed_output_tensor(0);

PRINTF("rn x_value: %f, y_value: %f rn", x_val, y_val[0]);

//PRINTF("rn x_value: %d, y_value: %d rn", (int)x_val, (int)y_val[0]);

// Increment the inference_counter, and reset it if we have reached

// the total number per cycle

inference_count += 1;

if (inference_count >= kInferencesPerCycle) inference_count = 0;

}

}

/*

* @brief Application entry point.

*/

int main(void)

{

/* Init board hardware */

BOARD_ConfigMPU();

BOARD_InitPins();

BOARD_InitDEBUG_UARTPins();

BOARD_BootClockRUN();

BOARD_InitDebugConsole();

NVIC_SetPriorityGrouping(3);

InitTimer();

std::cout << "The hello_world demo of TensorFlow Lite modelrn";

RunInference();

std::flush(std::cout);

for (;;) {}

}

此工程運(yùn)行之后打印信息如下:

可以看到,輸入的值x_value通過(guò)模型計(jì)算之后得到y(tǒng)_value

使用的就是計(jì)算網(wǎng)絡(luò)中的公式:

predictions = (input_holder * W1) + B1 + 10

由此可知,我們模型轉(zhuǎn)換成功。

飛凌嵌入式

飛凌嵌入式

保定飛凌嵌入式技術(shù)有限公司,創(chuàng)建于2006年,是一家專注嵌入式核心控制系統(tǒng)研發(fā)、設(shè)計(jì)和生產(chǎn)的高新技術(shù)企業(yè),是國(guó)內(nèi)較早專業(yè)從事嵌入式技術(shù)的企業(yè)之一。 經(jīng)過(guò)十幾年的發(fā)展與積累,公司擁有業(yè)內(nèi)優(yōu)秀的軟硬件研發(fā)團(tuán)隊(duì),在北京及保定建立兩大研發(fā)基地,在蘇州、深圳設(shè)有華東、華南技術(shù)服務(wù)中心,并在北美、歐洲以及亞太等其他國(guó)家和地區(qū)擁有國(guó)際業(yè)務(wù)網(wǎng)絡(luò)。公司研發(fā)的智能設(shè)備核心平臺(tái)廣泛應(yīng)用于物聯(lián)網(wǎng)、工控、軌道交通、醫(yī)療、電力、商業(yè)電子、智能家居、安防、機(jī)器人、環(huán)境監(jiān)測(cè)等諸多領(lǐng)域。

保定飛凌嵌入式技術(shù)有限公司,創(chuàng)建于2006年,是一家專注嵌入式核心控制系統(tǒng)研發(fā)、設(shè)計(jì)和生產(chǎn)的高新技術(shù)企業(yè),是國(guó)內(nèi)較早專業(yè)從事嵌入式技術(shù)的企業(yè)之一。 經(jīng)過(guò)十幾年的發(fā)展與積累,公司擁有業(yè)內(nèi)優(yōu)秀的軟硬件研發(fā)團(tuán)隊(duì),在北京及保定建立兩大研發(fā)基地,在蘇州、深圳設(shè)有華東、華南技術(shù)服務(wù)中心,并在北美、歐洲以及亞太等其他國(guó)家和地區(qū)擁有國(guó)際業(yè)務(wù)網(wǎng)絡(luò)。公司研發(fā)的智能設(shè)備核心平臺(tái)廣泛應(yīng)用于物聯(lián)網(wǎng)、工控、軌道交通、醫(yī)療、電力、商業(yè)電子、智能家居、安防、機(jī)器人、環(huán)境監(jiān)測(cè)等諸多領(lǐng)域。收起

查看更多

相關(guān)推薦

登錄即可解鎖
  • 海量技術(shù)文章
  • 設(shè)計(jì)資源下載
  • 產(chǎn)業(yè)鏈客戶資源
  • 寫文章/發(fā)需求
立即登錄

秉承專業(yè)態(tài)度,專注智能設(shè)備核心平臺(tái)研發(fā)與制造,以技術(shù)研發(fā)創(chuàng)新為主導(dǎo),以客戶實(shí)用化,產(chǎn)品化為目標(biāo),把握嵌入式行業(yè)的前沿發(fā)展需求,利用核心技術(shù)為客戶提供穩(wěn)定、可靠、功能優(yōu)異的高品質(zhì)產(chǎn)品。合作聯(lián)系:17713286011