如何建立网站做微商,低价网站建设多少钱,在线直播,qq空间刷赞推广网站

🍨 本文为🔗365天深度学习训练营中的学习记录博客
🍖 原作者：K同学啊

目录：环境；步骤（环境设置、数据准备、工具方法、模型设计、模型训练、模型效果展示）；总结与心得体会

上周已经简单地了解了ACGAN的原理，并且未经实践地编写了部分代码，这周复现一下真正的ACGAN。
环境
PyTorch 2.3.1+cu121，NVIDIA RTX 4090
步骤
环境设置
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global parameters
n_epochs = 200          # number of passes over the training set
batch_size = 64         # samples per mini-batch
lr = 0.0002             # Adam learning rate
b1 = 0.5                # Adam beta1
b2 = 0.999              # Adam beta2
n_cpu = 8               # defined for data loading; not referenced in the code shown here
latent_dim = 100        # dimensionality of the generator's noise vector
n_classes = 10          # number of class labels (MNIST digits)
img_size = 32           # images are resized to img_size x img_size
channels = 1            # image channels (grayscale)
sample_interval = 400   # save a sample grid every this many batches

# ---- Data preparation ----
# Create the folder that receives the intermediate sample images.
import os

os.makedirs("images", exist_ok=True)

# Configure the dataset: MNIST is downloaded into data/mnist on first use,
# resized to img_size, converted to a tensor and normalized with
# mean 0.5 / std 0.5.
os.makedirs("data/mnist", exist_ok=True)
dataloader = DataLoader(
    datasets.MNIST(
        "data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)
工具方法
# Weight initialization function
def weights_init_normal(m):
    """Initialize the parameters of module *m* in place.

    Intended to be passed to ``nn.Module.apply``. The decision is made on
    the module's class name:

    * names containing ``"Conv"``: weights drawn from N(0.0, 0.02);
    * names containing ``"BatchNorm2d"``: weights drawn from N(1.0, 0.02)
      and biases set to 0.

    Any other module is left untouched.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Logging helper: in a Jupyter notebook the output of long-running jobs
# cannot be inspected directly, so it is also written to a file.
import logging
import sys
import datetime


def init_logger(filename, logger_name):
    """Configure logging to a file and stdout and return a named logger.

    Records are duplicated to a file so that they survive a lost connection
    to the notebook.

    Args:
        filename: path of the file that receives all log records.
        logger_name: name (alias) of the logger to return.

    Returns:
        The ``logging.Logger`` registered under ``logger_name``.
    """
    # The original computed a timestamp here with the deprecated
    # datetime.utcnow() and never used it; that dead local is removed.
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(name)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(filename=filename),
            logging.StreamHandler(sys.stdout),
        ],
    )

    # Emit one record immediately so a broken configuration shows up at once.
    logger = logging.getLogger(logger_name)
    logger.info("### Init. Logger {} ###".format(logger_name))
    return logger


# Initialize
# Module-wide logger; records go to ./ml_notebook.log and to stdout.
my_logger = init_logger("./ml_notebook.log", "ml_logger")

# Saving the generator's outputs
def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits.

    Every row of the grid holds one image for each label ``0 .. n_row - 1``,
    so each column shows a single class. The grid is written to
    ``images/<batches_done>.png``.

    Args:
        n_row: number of rows and columns of the grid. The labels are fed to
            the generator's embedding, so this must not exceed ``n_classes``.
        batches_done: global batch counter, used as the file name.
    """
    # Sample noise
    z = torch.randn((n_row**2, latent_dim), device=device)
    # Labels 0 .. n_row-1, repeated once per row
    labels = torch.arange(n_row, device=device).repeat(n_row)
    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs.cpu(), "images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ---- Model design ----
# Generator
class Generator(nn.Module):
    """Class-conditional generator of the ACGAN.

    Maps a noise vector and a class label to an image with ``channels``
    channels and side length ``img_size``. Reads the module-level settings
    ``n_classes``, ``latent_dim``, ``img_size`` and ``channels``.
    """

    def __init__(self):
        super().__init__()
        # Label embedding, same width as the noise vector
        self.label_emb = nn.Embedding(n_classes, latent_dim)
        # Spatial size before the two 2x upsampling stages
        self.init_size = img_size // 4
        # First linear layer
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size**2))
        # Convolutional blocks
        # NOTE(review): the second positional argument of BatchNorm2d is
        # ``eps``; the original passed 0.8 positionally, so it is spelled out
        # here. If ``momentum=0.8`` was intended, confirm before changing it.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: tensor of shape ``(batch, latent_dim)``.
            labels: integer class indices of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, channels, img_size, img_size)`` with
            values in ``[-1, 1]`` (the last layer is ``Tanh``).
        """
        # Condition the noise on the label by element-wise multiplication
        gen_input = torch.mul(self.label_emb(labels), noise)
        # Project through the first linear layer
        out = self.l1(gen_input)
        # Reshape into a 128-channel feature map
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        # Upsample and convolve into the final image
        img = self.conv_blocks(out)
        return img
# Discriminator
class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and a class head.

    Reads the module-level settings ``channels``, ``img_size`` and
    ``n_classes``.
    """

    def __init__(self):
        super().__init__()

        # Factory for one discriminator block
        def discriminator_block(in_filters, out_filters, bn=True):
            """Return the layers of one stride-2 discriminator block."""
            # NOTE(review): 0.8 lands in BatchNorm2d's ``eps`` parameter
            # (second positional argument); spelled out here, not changed.
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        # Convolutional blocks
        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Width/height of the feature map after four stride-2 convolutions
        ds_size = img_size // 2**4

        # Output heads
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size**2, 1), nn.Sigmoid())
        # dim=1 is what the implicit choice resolves to for 2-D input; stating
        # it avoids the "implicit dimension" deprecation warning.
        # NOTE(review): this file feeds the softmax output to
        # nn.CrossEntropyLoss, which applies log-softmax itself — confirm
        # whether the class head should return raw logits instead.
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size**2, n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: tensor of shape ``(batch, channels, img_size, img_size)``.

        Returns:
            ``(validity, label)``: ``validity`` has shape ``(batch, 1)`` with
            values in ``[0, 1]``; ``label`` has shape ``(batch, n_classes)``
            and each row sums to 1.
        """
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label


# Model initialization

# Loss functions
adversarial_loss = nn.BCELoss()            # real/fake head
auxiliary_loss = nn.CrossEntropyLoss()     # class head

# Initialize generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# ---- Model training ----
# Training

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

for epoch in range(n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Size of this batch (the last one may be smaller). Kept under its
        # own name so the global ``batch_size`` setting is not overwritten.
        cur_batch_size = imgs.shape[0]

        # Target for "image is real"
        valid = torch.ones((cur_batch_size, 1), requires_grad=False, device=device)
        # Target for "image is generated"
        fake = torch.zeros((cur_batch_size, 1), requires_grad=False, device=device)

        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # ---- Train the generator ----
        optimizer_G.zero_grad()

        # Sample noise and labels as generator input. The upper bound of
        # torch.randint is exclusive, so it must be n_classes to cover every
        # class; the original bound of 1 produced label 0 only.
        z = torch.randn((cur_batch_size, latent_dim), device=device)
        gen_labels = torch.randint(0, n_classes, (cur_batch_size,), device=device)

        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)

        # Generator loss: fool the discriminator and hit the requested class
        validity, pred_label = discriminator(gen_imgs)
        g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))
        g_loss.backward()
        optimizer_G.step()

        # ---- Train the discriminator ----
        optimizer_D.zero_grad()

        # Loss on real images
        real_pred, real_aux = discriminator(real_imgs)
        d_real_loss = 0.5 * (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels))

        # Loss on generated images (detached: no gradient into the generator)
        fake_pred, fake_aux = discriminator(gen_imgs.detach())
        d_fake_loss = 0.5 * (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels))

        # Total discriminator loss
        d_loss = 0.5 * (d_real_loss + d_fake_loss)

        # Classification accuracy of the discriminator on real + generated images
        pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
        gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
        d_acc = np.mean(np.argmax(pred, axis=1) == gt)

        d_loss.backward()
        optimizer_D.step()

        if i % 100 == 0:
            my_logger.info(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]",
                epoch,
                n_epochs,
                i,
                len(dataloader),
                d_loss.item(),
                100 * d_acc,
                g_loss.item(),
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

# ---- Model results ----

# ---- Summary and takeaways ----
通过对模型的复现，发现我之前对判别器的理解有偏差：如果在判别器的输入中插入分类信息，等于是将答案直接给了判别器，生成的结果反而不会太好。还有一个和我预想的不一样的地方：在生成器中将标签嵌入到特征向量时，使用的是逐元素乘法（torch.mul），而没有直接使用 concatenate 操作。