+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2019-07(2)

2019-08(106)

2019-09(110)

2019-10(14)

2019-11(8)

TensorFlow中的compile和fit操作,简化神经网络模型代码

发布于2020-11-09 19:33     阅读(952)     评论(0)     点赞(3)     收藏(1)


0

1

2

3

4

5

6

7


前言

compile和fit是TensorFlow科学计算库中非常便捷是用来构建神经网络模型的API接口,它实现了模型数据的流向的定义,参数的更新,测试方法等等,极大的简化了代码,以下两段代码:第一段是不使用compile和fit函数实现mnist手写数字识别问题,第二段代码使用是使用compile和fit函数实现mnist手写数字识别问题,两者作为对比,方便理解。
提示:以下是本篇文章正文内容,下面案例可供参考

一、

# 前言
这是一个可以真正跑下来的全连接神经网络识别手写数字问题的代码哟,又不懂的语句或者逻辑,欢迎评论区留言
# 代码

```python
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,optimizers,datasets

#数据预处理
(x,y),(x_val,y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x,dtype=tf.float32)/255.
y = tf.convert_to_tensor(y,dtype=tf.int32)
y = tf.one_hot(y,depth=10)
print(x.shape,y.shape)
train_dateset = tf.data.Dataset.from_tensor_slices((x,y))
train_dateset = train_dateset.batch(200)


#建立神经网络模型
model = keras.Sequential(
    [
        layers.Dense(512,activation='relu'),
        layers.Dense(256,activation='relu'),
        layers.Dense(128,activation='relu'),
        layers.Dense(64,activation='relu'),
        layers.Dense(10)
    ]
)

optimizers = optimizers.SGD(learning_rate=0.001)

#定义训练流程
def train_epoch(epoch):

    for step,(x,y) in enumerate(train_dateset):

        with tf.GradientTape() as tape:

            x = tf.reshape(x,(-1,28*28))

            out = model(x)

            loss = tf.reduce_sum(tf.square(out-y))/x.shape[0]

        grads = tape.gradient(loss,model.trainable_variables)

        optimizers.apply_gradients(zip(grads,model.trainable_variables))

        if step%100 == 0:
            print(epoch,step,'loss',loss.numpy())

    #运行
def train():
        for epoch in range(30):
            train_epoch(epoch)



if __name__ == '__main__':
    train()

# 二、使用步骤
## 1.引入库


<font color=#999AAA >代码如下(示例):



```c
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import  ssl
ssl._create_default_https_context = ssl._create_unverified_context

代码如下(示例):

import  tensorflow as tf
from    tensorflow.keras import datasets, layers, optimizers, Sequential, metrics


def preprocess(x, y):
    """
    x is a simple image, not a batch
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28*28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x,y


batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())



db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz) 

sample = next(iter(db))
print(sample[0].shape, sample[1].shape)


network = Sequential([layers.Dense(256, activation='relu'),
                     layers.Dense(128, activation='relu'),
                     layers.Dense(64, activation='relu'),
                     layers.Dense(32, activation='relu'),
                     layers.Dense(10)])
network.build(input_shape=(None, 28*28))
network.summary()




network.compile(optimizer=optimizers.Adam(lr=0.01),
		loss=tf.losses.CategoricalCrossentropy(from_logits=True),
		metrics=['accuracy']
	)

network.fit(db, epochs=5, validation_data=ds_val, validation_freq=2)
 
network.evaluate(ds_val)

sample = next(iter(ds_val))
x = sample[0]
y = sample[1] # one-hot
pred = network.predict(x) # [b, 10]
# convert back to number 
y = tf.argmax(y, axis=1)
pred = tf.argmax(pred, axis=1)

print(pred)
print(y)


总结

提示:这里对文章进行总结:
例如:以上就是今天要讲的内容,本文仅仅简单介绍了compile和fit的使用,而这两个API接口能使我们快速便捷地处理数据的函数和方法。

0

1

2

3

4

5

6

7

8

9



所属网站分类: 技术文章 > 博客

作者:9384vfnv

链接: https://www.pythonheidong.com/blog/article/612085/1bc5c615be74beed5a68/

来源: python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

3 0
收藏该文
已收藏

评论内容:(最多支持255个字符)