+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2020-03(62)

2020-04(60)

2020-05(24)

2020-06(39)

2020-07(23)

荐深度神经网络学习笔记----Keras框架用自己的图像数据构建网络模型,用Tensorboard显示 loss,acc,val_loss, val_acc

发布于2020-05-29 22:43     阅读(860)     评论(0)     点赞(12)     收藏(2)


0

1

2

3

4

5

本文是一篇学习笔记,记录用keras框架训练用于自己的数据集的图像分类网络,并使用Tensorboard 将训练过程可视化.

把tensorboard 加入训练的代码中,

tbCallBack = TensorBoard(log_dir="./model", histogram_freq=1, write_grads=True)
history = model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size),
								steps_per_epoch=max(1, num_train // batch_size),
								validation_data=generate_arrays_from_file(lines[num_train:], batch_size),
								validation_steps=max(1, num_val // batch_size),
								epochs=100,
                                initial_epoch=0,
                                callbacks=[tbCallBack])

出现报错:

ValueError: An operation has “None” for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

ValueError: If printing histograms, validation_data must be provided, and cannot be a generator.

把参数histogram_freq=1更改为histogram_freq=0。(参数histogram_freq表示对于模型中各个层计算激活值和模型权重直方图的频率(训练轮数中)。 如果设置成 0 ,直方图不会被计算。)
因为我需要在tensorboard中记录loss和acc的变化曲线,所以把histogram_freq=0设置成了0,我知道我并没有从根本上解决这个问题,恳请指教。

https://blog.csdn.net/w5688414/article/details/89042489


1、训练过程

把所有待训练的数据保存在data路径下。
图像命名格式为:class_name.class_num.png , 用.作为分隔符。

1、编辑index.txt文件
一共分为n类
0;class_name_1
1;class_name_2
2;class_name_3
3;class_name_4
……
n-1;class_name_n

2、生成训练数据对应标签的.txt文件

import os 

with open('.txt文件的保存路径','w') as f:
    after_generate = os.listdir("训练数据的保存路径")
    for image in after_generate:
        if image.split(".")[0] == 'class_name_1':
            f.write(image + ";" + "0" + "\n")
        elif image.split(".")[0] == 'class_name_2':
            f.write(image + ";" + "1" + "\n")
        elif image.split(".")[0] == 'class_name_3':
            f.write(image + ";" + "2" + "\n")
        else:
            f.write(image + ";" + "3" + "\n")

生成的.txt文件结果如图:
在这里插入图片描述
3、keras中用于模型训练的函数
https://blog.csdn.net/learning_tortosie/article/details/85243310
https://blog.csdn.net/LuYi_WeiLin/article/details/88555813
https://blog.csdn.net/qq_32951799/article/details/82918098

  • .fit函数
  • .fit_generator函数
  • train_on batch函数

(1).fit函数
对于训练数据较少,模型简单而且不需要数据增强的训练任务,可以使用.fit函数。

model.fit(trainX, trainY, batch_size=32, epochs=50)

trainX表示用于训练的数据集
trainY表示训练集的标签
batch_size=32表示每一个batch有32张图片输入到网络中
epochs=50表示训练过程迭代50次

显然,小规模的数据集并不能满足所有的训练任务。(并不是所有的任务都能用Mnist手写数字的数据集来完成)。通常我们需要采用数据增强来提高模型的泛化能力,例如旋转、缩放、改变宽高、剪切、对称、填充……:

	width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
	horizontal_flip=True, fill_mode="nearest") ```

(2).fit_generator函数
那么.fit_generator函数为什么就能够进行数据增强呢?

数据数据增强后,用于训练的数据集将是不断变化的。如果把.fit函数看作是把所有的数据集一次全都输入到网络中,那么.fit_generator函数可以看成将用于训练的数据集按照批次,分批的放入网络中。

在代码定义了一个数据生成器,用来将一定数量(batch_size)的训练数据输入到网络中,更新网络中的权重。
特别地,.fit_generator函数中有一个参数steps_per_epoch

steps_per_epoch=max(1, num_train // batch_size)

数据生成器不断地输出一个又一个大小为batch_size 的数据,每次取一个batch_size的数据,当读取数据的次数达到steps_per_epoch则进入到下一个epoch。

keras中数据生成器是无线循环的, 永远不会返回和退出,使用steps_per_epoch这一参数能够定位当前进行到哪个epoch

(3) Keras train_on_batch函数

对于寻求对Keras模型进行精细控制( finest-grained control)的深度学习实践者,使用 Keras train_on_batch函数

Keras train_on_batch函数接受单批数据,执行反向传播,然后更新模型参数。batch_size可以是任意大小,不需要指定batch_size的大小。


我自己的训练模型的代码使用的是.fit_generator函数

2、用Tensorboard进行可视化

之前写过了,代码就是这几行:

tbCallBack = TensorBoard(log_dir="./model", histogram_freq=0, write_grads=True)
history = model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size),
                        steps_per_epoch=max(1, num_train // batch_size),
                        validation_data=generate_arrays_from_file(lines[num_train:], batch_size),
                        validation_steps=max(1, num_val // batch_size),
                        epochs=100,
                        initial_epoch=0,
                        callbacks=[tbCallBack])

https://blog.csdn.net/weixin_44791964/article/details/105002793
https://blog.csdn.net/qq_27825451/article/details/90229983

keras使用tensorboard是通过回调函数来实现的
Tensorboard的参数有很多,主要的7个参数分别为

  • log_dir: 日志文件保存的路径
  • histogram_freq: 计算模型中各个层计算激活值和模型权重直方图的频率(训练轮数中)
  • write_graph: 布尔值 是否在 TensorBoard 中可视化图像
  • write_grads: 在histogram_freq 大于0前提下,是否在 TensorBoard 中可视化梯度值直方图
  • batch_size: 计算直方图时传入神经元网络batch的大小
  • write_images: 布尔值 是否在 TensorBoard 中将模型权重以图片可视化,如果设置为True,日志文件会变得非常大。
  • update_freq: ‘batch’ 或 ‘epoch’ 或 整数(默认值是epoch)。也就是按照训练过程中的哪一个频率来保存tensorboard文件。当使用 ‘batch’ 时,在每个 batch 之后将损失和评估值写入到 TensorBoard 中。同样的情况应用到 ‘epoch’ 中。如果使用整数,例如 10000,这个回调会在每 10000 个样本之后将损失和评估值写入到 TensorBoard 中。注意,频繁地写入到 TensorBoard 会减缓训练。
    参数的默认值:
log_dir='./logs',  
histogram_freq=0,
batch_size=32,
write_graph=True,  
write_grads=False,
write_images=False,
embeddings_freq=0,
embeddings_layer_names=None,
embeddings_metadata=None,
embeddings_data=None,
update_freq='epoch'

tensorboard函数中指定的logdir参数的路径就是日志文件保存的路径,tensorboard就是利用保存在日志文件中的信息来生成loss和acc的变化曲线。
启动tensorboard的步骤:

  1. 打开cmd运行,激活tensorflow:输入activate tensorflow
  2. 进入当前项目所在的盘符,例如E盘,输入: E:
  3. 复制日志文件所在路径的上一层的路径,例如日志文件保存在和代码同级目录的model中,那么不要打开model,只需要复制到model的上一层的路径!
  4. 在cmd中输入cd +复制的路径,enter
  5. 继续在cmd中输入tensorboard --logdir= ‘保存日志文件的那个文件夹名称’,也就是说这个文件夹名称和复制的那个路径连接在一起就是日志文件的完整保存路径。enter
  6. 得到了一个网址,不出意外,这个网址在google浏览器是打不开的,输入网址http://localhost:6006。

既然得到的网址也不用,为啥还费这劲呢。
只有cmd一直打开,才能看到tensorboard的结果哟!
结果如图所示:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

至此,利用tensorboard模型训练过程可视化完成。


感谢:
https://blog.csdn.net/w5688414/article/details/89042489
https://blog.csdn.net/learning_tortosie/article/details/85243310
https://blog.csdn.net/LuYi_WeiLin/article/details/88555813
https://blog.csdn.net/qq_32951799/article/details/82918098
https://blog.csdn.net/weixin_44791964/article/details/105002793
https://blog.csdn.net/qq_27825451/article/details/90229983
谢谢你的辛勤耕耘,让我受益匪浅

原文链接:https://blog.csdn.net/weixin_43227526/article/details/106373899

0

1

2

3

4



所属网站分类: 技术文章 > 博客

作者:9384vfnv

链接: https://www.pythonheidong.com/blog/article/397505/1ca0c6de6a655b0f4db6/

来源: python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

12 0
收藏该文
已收藏

评论内容:(最多支持255个字符)