音效素材网提供各类素材,打造精品素材网站!

站内导航 站长工具 投稿中心 手机访问

音效素材

超详细PyTorch实现手写数字识别器的示例代码
日期:2021-09-08 13:54:58   来源:脚本之家

前言

深度学习中有很多玩具数据,mnist就是其中一个,一个人能否入门深度学习往往就是以能否玩转mnist数据来判断的,在前面很多基础介绍后我们就可以来实现一个简单的手写数字识别的网络了

数据的处理

我们使用pytorch自带的包进行数据的预处理

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5), (0.5))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,num_workers=2)

注释:transforms.Normalize用于数据的标准化,具体实现
mean:均值 总和后除个数
std:方差 每个元素减去均值再平方再除个数

norm_data = (tensor - mean) / std

这里就直接将图片标准化到了-1到1的范围,标准化的原因就是因为如果某个数在数据中很大很大,就导致其权重较大,从而影响到其他数据,而本身我们的数据都是平等的,所以标准化后将数据分布到-1到1的范围,使得所有数据都不会有太大的权重导致网络出现巨大的波动
trainloader现在是一个可迭代的对象,那么我们可以使用for循环进行遍历了,由于是使用yield返回的数据,为了节约内存

观察一下数据

def imshow(img):
   img = img / 2 + 0.5 # unnormalize
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()
# torchvision.utils.make_grid 将图片进行拼接
imshow(torchvision.utils.make_grid(iter(trainloader).next()[0]))

在这里插入图片描述

构建网络

from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=5) # 14
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2) # 无参数学习因此无需设置两个
    self.conv2 = nn.Conv2d(in_channels=28, out_channels=28*2, kernel_size=5) # 7
    self.fc1 = nn.Linear(in_features=28*2*4*4, out_features=1024)
    self.fc2 = nn.Linear(in_features=1024, out_features=10)
  def forward(self, inputs):
    x = self.pool(F.relu(self.conv1(inputs)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(inputs.size()[0],-1)
    x = F.relu(self.fc1(x))
    return self.fc2(x)

下面是卷积的动态演示

在这里插入图片描述

in_channels:为输入通道数 彩色图片有3个通道 黑白有1个通道
out_channels:输出通道数
kernel_size:卷积核的大小
stride:卷积的步长
padding:外边距大小

输出的size计算公式

  • h = (h - kernel_size + 2*padding)/stride + 1
  • w = (w - kernel_size + 2*padding)/stride + 1

MaxPool2d:是没有参数进行运算的

在这里插入图片描述

实例化网络优化器,并且使用GPU进行训练

net = Net()
opt = torch.optim.Adam(params=net.parameters(), lr=0.001)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
Net(
 (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1))
 (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (conv2): Conv2d(28, 56, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=896, out_features=1024, bias=True)
 (fc2): Linear(in_features=1024, out_features=10, bias=True)
)

训练主要代码

for epoch in range(50):
  for images, labels in trainloader:
    images = images.to(device)
    labels = labels.to(device)
    pre_label = net(images)
    loss = F.cross_entropy(input=pre_label, target=labels).mean()
    pre_label = torch.argmax(pre_label, dim=1)
    acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
    net.zero_grad()
    loss.backward()
    opt.step()
  print(acc.detach().cpu().numpy(), loss.detach().cpu().numpy())

F.cross_entropy交叉熵函数

在这里插入图片描述

源码中已经帮助我们实现了softmax因此不需要自己进行softmax操作了
torch.argmax计算最大数所在索引值

acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
# pre_label==labels 相同维度进行比较相同返回True不同的返回False,True为1 False为0, 即可获取到相等的个数,再除总个数,就得到了Accuracy准确度了

预测

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True,num_workers=2)
images, labels = iter(testloader).next()
images = images.to(device)
labels = labels.to(device)
with torch.no_grad():
  pre_label = net(images)
  pre_label = torch.argmax(pre_label, dim=1)
  acc = (pre_label==labels).sum()/torch.tensor(labels.size()[0], dtype=torch.float32)
  print(acc)

总结

本节我们了解了标准化数据·卷积的原理简答的构建了一个网络,并让它去识别手写体,也是对前面章节的总汇了

到此这篇关于超详细PyTorch实现手写数字识别器的示例代码的文章就介绍到这了,更多相关PyTorch 手写数字识别器内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!

    您感兴趣的教程

    在docker中安装mysql详解

    本篇文章主要介绍了在docker中安装mysql详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编...

    详解 安装 docker mysql

    win10中文输入法仅在桌面显示怎么办?

    win10中文输入法仅在桌面显示怎么办?

    win10系统使用搜狗,QQ输入法只有在显示桌面的时候才出来,在使用其他程序输入框里面却只能输入字母数字,win10中...

    win10 中文输入法

    一分钟掌握linux系统目录结构

    这篇文章主要介绍了linux系统目录结构,通过结构图和多张表格了解linux系统目录结构,感兴趣的小伙伴们可以参考一...

    结构 目录 系统 linux

    PHP程序员玩转Linux系列 Linux和Windows安装

    这篇文章主要为大家详细介绍了PHP程序员玩转Linux系列文章,Linux和Windows安装nginx教程,具有一定的参考价值,感兴趣...

    玩转 程序员 安装 系列 PHP

    win10怎么安装杜比音效Doby V4.1 win10安装杜

    第四代杜比®家庭影院®技术包含了一整套协同工作的技术,让PC 发出清晰的环绕声同时第四代杜比家庭影院技术...

    win10杜比音效

    纯CSS实现iOS风格打开关闭选择框功能

    这篇文章主要介绍了纯CSS实现iOS风格打开关闭选择框,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作...

    css ios c

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的办法

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的

    Win7给电脑C盘扩容的办法大家知道吗?当系统分区C盘空间不足时,就需要给它扩容了,如果不管,C盘没有足够的空间...

    Win7 C盘 扩容

    百度推广竞品词的投放策略

    SEM是基于关键词搜索的营销活动。作为推广人员,我们所做的工作,就是打理成千上万的关键词,关注它们的质量度...

    百度推广 竞品词

    Visual Studio Code(vscode) git的使用教程

    这篇文章主要介绍了详解Visual Studio Code(vscode) git的使用,小编觉得挺不错的,现在分享给大家,也给大家做个参考。...

    教程 Studio Visual Code git

    七牛云储存创始人分享七牛的创立故事与

    这篇文章主要介绍了七牛云储存创始人分享七牛的创立故事与对Go语言的应用,七牛选用Go语言这门新兴的编程语言进行...

    七牛 Go语言

    Win10预览版Mobile 10547即将发布 9月19日上午

    微软副总裁Gabriel Aul的Twitter透露了 Win10 Mobile预览版10536即将发布,他表示该版本已进入内部慢速版阶段,发布时间目...

    Win10 预览版

    HTML标签meta总结,HTML5 head meta 属性整理

    移动前端开发中添加一些webkit专属的HTML5头部标签,帮助浏览器更好解析HTML代码,更好地将移动web前端页面表现出来...

    移动端html5模拟长按事件的实现方法

    这篇文章主要介绍了移动端html5模拟长按事件的实现方法的相关资料,小编觉得挺不错的,现在分享给大家,也给大家...

    移动端 html5 长按

    HTML常用meta大全(推荐)

    这篇文章主要介绍了HTML常用meta大全(推荐),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参...

    cdr怎么把图片转换成位图? cdr图片转换为位图的教程

    cdr怎么把图片转换成位图? cdr图片转换为

    cdr怎么把图片转换成位图?cdr中插入的图片想要转换成位图,该怎么转换呢?下面我们就来看看cdr图片转换为位图的...

    cdr 图片 位图

    win10系统怎么录屏?win10系统自带录屏详细教程

    win10系统怎么录屏?win10系统自带录屏详细

    当我们是使用win10系统的时候,想要录制电脑上的画面,这时候有人会想到下个第三方软件,其实可以用电脑上的自带...

    win10 系统自带录屏 详细教程

    + 更多教程 +
    ASP编程JSP编程PHP编程.NET编程python编程