一、tf.keras
tf.keras是TensorFlow 2.0的高阶API接口,为TensorFlow的代码提供了新的风格和设计模式,大大提升了TF代码的简洁性和复用性,官方也推荐使用tf.keras来进行模型设计和开发。
1.1 tf.keras中常用模块
如下表所示:
1.2 常用方法
深度学习实现的主要流程:
1.数据获取,
2 数据处理,
3 模型创建与训练,
4 模型测试与评估,
5.模型预测
导入tf.keras
使用 tf.keras
,首先需要在代码开始时导入tf.keras
import tensorflow as tf
from tensorflow import keras
数据输入
对于小的数据集,可以直接使用numpy格式的数据进行训练、评估模型,对于大型数据集或者要进行跨设备训练时使用tf.data.datasets来进行数据输入。
模型构建
- 简单模型使用Sequential进行构建
- 复杂模型使用函数式编程来构建
- 自定义layers
训练与评估
- 配置训练过程
# 配置优化方法,损失函数和评价指标
model.compile(optimizer=tf.train.AdamOptimizer(0.001),loss='categorical_crossentropy',metrics=['accuracy'])
- 模型训练
# 指明训练数据集,训练epoch,批次大小和验证集数据
model.fit/fit_generator(dataset, epochs=10, batch_size=3,validation_data=val_dataset,)
- 模型评估
# 指明评估数据集和批次大小
model.evaluate(x, y, batch_size=32)
- 模型预测
# 对新的样本进行预测
model.predict(x, batch_size=32)
回调函数(callbacks)
回调函数用在模型训练过程中,来控制模型训练行为,可以自定义回调函数,也可使用tf.keras.callbacks 内置的 callback :
ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:动态改变学习速率。 EarlyStopping:当验证集上的性能不再提高时,终止训练。 TensorBoard:使用 TensorBoard 监测模型的状态。
模型的保存和恢复
- 只保存参数
# 只保存模型的权重
model.save_weights('./my_model')
# 加载模型的权重
model.load_weights('my_model')
- 保存整个模型
# 保存模型架构与权重在h5文件中
model.save('my_model.h5')
# 加载模型:包括架构和对应的权重
model = keras.models.load_model('my_model.h5')