网站管理员权限设置权限设置,顺企网杭州网站建设,做网站设像素,广东网站建设模版(一) pytorch 官网内置的网络模型
图像处理#xff1a;
Models and pre-trained weights — Torchvision 0.20 documentation
(二) CIFAR10数据集的分类网络模型#xff08;仅前向传播#xff09;#xff1a; 下方的网络模型图片有误#xff0c;已做修改#xff0c;具…(一) pytorch 官网内置的网络模型
图像处理
Models and pre-trained weights — Torchvision 0.20 documentation
(二) CIFAR10数据集的分类网络模型仅前向传播 下方的网络模型图片有误已做修改具体情参考代码。 1代码如下
无 Sequential() 函数的 demo Sequential() 函数可以快速定义一个前馈神经网路按顺序堆叠不同的层但是要保证层之间的输入和输出尺寸要匹配。
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoaderclass CIFAR10_NET(nn.Module):def __init__(self):super(CIFAR10_NET, self).__init__()self.conv1 nn.Conv2d(3, 32, 5,padding2) # 输入输出尺寸相同故根据公式计算出padding的值self.pool1 nn.MaxPool2d(2, 2)self.conv2 nn.Conv2d(32, 32, 5,padding2)self.pool2 nn.MaxPool2d(2, 2)self.conv3 nn.Conv2d(32, 64, 5,padding2)self.pool3 nn.MaxPool2d(2, 2)self.flatten nn.Flatten()self.linear1 nn.Linear(1024, 64)self.linear2 nn.Linear(64, 10)def forward(self, x):x self.conv1(x)x self.pool1(x)x self.conv2(x)x self.pool2(x)x self.conv3(x)x self.pool3(x)x self.flatten(x)x self.linear1(x)x self.linear2(x)return xCIFAR10_NET_Instance CIFAR10_NET()
print(CIFAR10_NET_Instance)有 Sequential() 函数的 demo
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriterclass CIFAR10_NET(nn.Module):def __init__(self):super(CIFAR10_NET, self).__init__()self.model nn.Sequential(nn.Conv2d(3, 32, 5, padding2), # 输入输出尺寸相同故根据卷积层的公式计算出padding的值此时默认stride1nn.MaxPool2d(2, 2),nn.Conv2d(32, 32, 5, padding2),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, 5, padding2),nn.MaxPool2d(2, 2),nn.Flatten(),nn.Linear(1024, 64),nn.Linear(64, 10))def forward(self, x):x self.model(x)return xCIFAR10_NET_Instance CIFAR10_NET()
print(CIFAR10_NET_Instance)writer SummaryWriter(logs)
writer.add_graph(CIFAR10_NET_Instance, (torch.rand(1, 3, 32, 32), )) # 在tensorboard中将计算图可视化
writer.close()
在命令行使用 tensorboard 的效果图 双击网络模型名 继续双击会出现更多的细节内容
2注意点 如果想要输入和输出的尺寸相同的话需要按照卷积层中的公式来计算 padding 和 stride 的值具体情参考笔记十。 一般先搭建网络在导入数据集之前往往先用以下代码进行测试 # 先创建网络模型实例假设为 test_net
input torch.ones((64,in_channels,H_in,W_in)) # in_channels、H_in、W_in根据数据集的输入设置
output test_net(input)
print(output.shape)如果网络模型有错误就会报错。 上一篇下一篇神经网络入门实战十三待发布