网站微信认证费用,南京网站设计制作,天津招标信息网官网,鞍山是哪个省哪个市#x1f368; 本文为#x1f517;365天深度学习训练营中的学习记录博客#x1f356; 原作者#xff1a;K同学啊
目录 一、导入数据并检查
二、配置数据集
三、数据可视化
四、构建模型
五、训练模型
六、模型对比评估
七、总结 一、导入数据并检查
import pathlib,… 本文为365天深度学习训练营中的学习记录博客 原作者K同学啊
目录 一、导入数据并检查
二、配置数据集
三、数据可视化
四、构建模型
五、训练模型
六、模型对比评估
七、总结 一、导入数据并检查
import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams[font.sans-serif] [SimHei] # 用来正常显示中文标签data_dir pathlib.Path(./T6)
image_count len(list(data_dir.glob(*/*)))batch_size 16
img_height 336
img_width 336关于image_dataset_from_directory()的详细介绍可以参考文章https://mtyjkh.blog.csdn.net/article/details/117018789train_ds tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split0.2,subsettraining,seed12,image_size(img_height, img_width),batch_sizebatch_size)
val_ds tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split0.2,subsetvalidation,seed12,image_size(img_height, img_width),batch_sizebatch_size) class_names train_ds.class_names
print(class_names) for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break 二、配置数据集
AUTOTUNE tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds (train_ds.cache().shuffle(1000).map(train_preprocessing) # 这里可以设置预处理函数
# .batch(batch_size) # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_sizeAUTOTUNE)
)val_ds (val_ds.cache().shuffle(1000).map(train_preprocessing) # 这里可以设置预处理函数
# .batch(batch_size) # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_sizeAUTOTUNE)
)
三、数据可视化
plt.figure(figsize(10, 8)) # 图形的宽为10高为5
plt.suptitle(数据展示)for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show() 四、构建模型
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizeradam):# 加载预训练模型vgg16_base_model tf.keras.applications.vgg16.VGG16(weightsimagenet,include_topFalse,#不包含顶层的全连接层input_shape(img_width, img_height, 3),poolingavg)#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable False #将 trainable属性设置为 False 意味着在训练过程中这些层的权重不会更新X vgg16_base_model.outputX Dense(170, activationrelu)(X)X BatchNormalization()(X)X Dropout(0.5)(X)output Dense(len(class_names), activationsoftmax)(X)#神经元数量等于类别数vgg16_model Model(inputsvgg16_base_model.input, outputsoutput)vgg16_model.compile(optimizeroptimizer,losssparse_categorical_crossentropy,metrics[accuracy])return vgg16_modelmodel1 create_model(optimizertf.keras.optimizers.Adam())
model2 create_model(optimizertf.keras.optimizers.SGD())#随机梯度下降SGD优化器的
model2.summary() 五、训练模型
NO_EPOCHS 20history_model1 model1.fit(train_ds, epochsNO_EPOCHS, verbose1, validation_dataval_ds)
history_model2 model2.fit(train_ds, epochsNO_EPOCHS, verbose1, validation_dataval_ds)
六、模型对比评估
from matplotlib.ticker import MultipleLocator
plt.rcParams[savefig.dpi] 300 #图片像素
plt.rcParams[figure.dpi] 300 #分辨率acc1 history_model1.history[accuracy]
acc2 history_model2.history[accuracy]
val_acc1 history_model1.history[val_accuracy]
val_acc2 history_model2.history[val_accuracy]loss1 history_model1.history[loss]
loss2 history_model2.history[loss]
val_loss1 history_model1.history[val_loss]
val_loss2 history_model2.history[val_loss]epochs_range range(len(acc1))plt.figure(figsize(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, labelTraining Accuracy-Adam)
plt.plot(epochs_range, acc2, labelTraining Accuracy-SGD)
plt.plot(epochs_range, val_acc1, labelValidation Accuracy-Adam)
plt.plot(epochs_range, val_acc2, labelValidation Accuracy-SGD)
plt.legend(loclower right)
plt.title(Training and Validation Accuracy)
# 设置刻度间隔x轴每1一个刻度
ax plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, labelTraining Loss-Adam)
plt.plot(epochs_range, loss2, labelTraining Loss-SGD)
plt.plot(epochs_range, val_loss1, labelValidation Loss-Adam)
plt.plot(epochs_range, val_loss2, labelValidation Loss-SGD)
plt.legend(locupper right)
plt.title(Training and Validation Loss)# 设置刻度间隔x轴每1一个刻度
ax plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show() 可以看出在这个实例中Adam优化器的效果优于SGD优化器
七、总结 通过本次实验学会了比较不同优化器Adam和SGD在训练过程中的性能表现可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能在研究论文中可以通过这些优化方法可以提高工作量。