本站消息

站长简介/公众号


站长简介:逗比程序员,理工宅男,前每日优鲜python全栈开发工程师,利用周末时间开发出本站,欢迎关注我的微信公众号:幽默盒子,一个专注于搞笑,分享快乐的公众号

  价值13000svip视频教程,python大神匠心打造,零基础python开发工程师视频教程全套,基础+进阶+项目实战,包含课件和源码

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2020-10(57)

2020-11(18)

tensorflow实现quantization-aware training(伪量化,fake quantization)

发布于2019-08-07 12:34     阅读(2742)     评论(0)     点赞(1)     收藏(0)



版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lishanlu136/article/details/88872266

前面一篇文章讲模型优化的时候有讲到量化模型,但那只是量化权重,在实际计算的时候还是会反量化回去,用float32位计算,没有进行实际意义上的定点运算。今天讲的这个方式是可以部署在移动端进行定点运算的,乘现在网上关于这方面资料很少,赶紧写一篇,求赞呀~~~

源代码位置:tensorflow/contrib/quantize/
github参考:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize
tensorflow实例参考:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands

为啥叫伪量化?

因为它只是通过在训练时向某些能识别的操作中加入fake_quantization_node,用于统计该节点的最大值,最小值。这里统计的最大值,最小值用于后面用toco工具完全量化操作,从而减小量化操作带来的精度损失。

注意: 某些网络的某些特殊操作目前还不支持自动向图中加入fake_quantization_node统计最大,最小值,需要自己手动加入节点统计,统计得不准会带来精度大大地下降,慎用,(如果有些节点在用toco转换的过程需要用到最大值最小值,而模型在训练过程中又没有插入fake_quantization_node自动统计,它会提示你需要指定默认的最大值,最小值。)。

具体步骤

第一步: 在train.py中的loss之后,train_op之前加入tf.contrib.quantize.create_training_graph(input_graph=tf.get_default_graph(), quant_delay=20000)
训练模型,保存ckpt文件。
注意,这里的quant_delay是训练迭代多少次后,网络开始做量化统计最大值,最小值,并用8bit做反向传播更新梯度。
如果你之前有训练好一个完整的模型,可直接加载这个模型进来做微调,这时可设置quant_delay=0,在定义saver的时候,
用saver = tf.train.Saver(tf.global_variables())或者saver=tf.train.Saver()较为保险,不然后面在freeze.py中加载模型进去会说有些节点没有权重初始化。(意思是说图中的有些节点,ckpt中没有保存参数,我就碰到过这种情况)

第二步: 在freeze.py中,构建你的inference_graph,加入tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph()),restore之前训练保存的ckpt文件,冻结生成pb文件。

第三步: 把pb文件转换成tflite文件,我这里是调用的tensorflow提供的python API,代码如下:

import tensorflow as tf
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
graph_def_file = "train_results/eval_model.pb"
input_arrays = ["input"]
output_arrays = ["softmax"]
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
converter.quantized_input_stats = {input_arrays[0]: (73.0, 10.00667)}    # mean, std_dev,需要自己从训练集(增强后,输入网络之前的)统计出来
tflite_model = converter.convert()
open("train_results/freeze_models/converted_model.tflite", "wb").write(tflite_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

在这里插入图片描述
最后能生成converted_model.tflite文件,大小约为转换之前的eval_model.pb的1/4左右。

第四步 测试converted_model.tflite

import numpy as np
import tensorflow as tf
import scipy
import os
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="train_results/freeze_models/converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

image_origin = scipy.misc.imread("src/151105230861_0_76.536.jpg", mode='RGB')   
image_q = (image_origin + 7.288) * 10.006671114      #这些参数通过打印input_details可以看到
image_ = np.array([image_q.astype('uint8')])

print(image_.shape)
print(type(image_))

interpreter.set_tensor(input_details[0]['index'], image_)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print(output_data.shape, type(output_data))

def f(x):
    return 0.0078125*(x-128)        # 这些参数也是在output_details里可以看到

output_data_new = map(f, output_data[0])    #转换后output_data_new为浮点数,你可以和没量化的模型输出对比一下相似度
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

写在后面的话

可通过netron查看模型文件的结构。
我的eval_model.pb模型结构截图
从图中可以看到有很多FakeQuantWithMinMaxVars节点,并还有相应的最大值,最小值,那都是模型在训练时,自动统计的。
我的tflite模型结构截图:
从图中可以看到输入输出节点都有量化前后的转换参数及浮点数的范围。
最后补充
我之前用pb转tflite是用toco工具的,但始终报错,在这里折腾了几周,我的是tensorflow1.12版本的,报错信息如下:
这个错误信息应该是说不支持pooling操作???
最后偶然看到tensorflow社区里的帖子,居然还有python的API接口,遂用api试一下转tflite,于是成功了,但测试的精度不咋地,那是因为我的均值和方差没有统计。后面怎么统计的呢,我是在训练过程中的图片input节点那里插入了两个节点:

image_max = tf.reduce_max(image_batch, name='image_max')
image_min = tf.reduce_min(image_batch, name='image_min')
  • 1
  • 2

用于统计图片的最大值和最小值,然后mean=255min/(min-max),std_dev=255/max-min

踩过的坑:
训练时,用slim写的网络,is_train不要用placeholder,用了会导致某些节点不会插入fake_quantized_node,slim.fully_connected好像不会自动插入fake_quantized_node,或者说slim.fully_connected加了batchnorm不会自动插入fake_quantized_node统计最大值最小值。

参考
tensorflowLite的量化使用问题,帖子很长,慢慢看吧






所属网站分类: 技术文章 > 博客

作者:大将军

链接:https://www.pythonheidong.com/blog/article/10855/3c006a1dd7107cf2e2de/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

1 0
收藏该文
已收藏

评论内容:(最多支持255个字符)