音效素材网提供各类素材,打造精品素材网站!

站内导航 站长工具 投稿中心 手机访问

音效素材

Pytorch实现WGAN用于动漫头像生成
日期:2021-09-08 13:39:24   来源:脚本之家

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!

    您感兴趣的教程

    在docker中安装mysql详解

    本篇文章主要介绍了在docker中安装mysql详解,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编...

    详解 安装 docker mysql

    win10中文输入法仅在桌面显示怎么办?

    win10中文输入法仅在桌面显示怎么办?

    win10系统使用搜狗,QQ输入法只有在显示桌面的时候才出来,在使用其他程序输入框里面却只能输入字母数字,win10中...

    win10 中文输入法

    一分钟掌握linux系统目录结构

    这篇文章主要介绍了linux系统目录结构,通过结构图和多张表格了解linux系统目录结构,感兴趣的小伙伴们可以参考一...

    结构 目录 系统 linux

    PHP程序员玩转Linux系列 Linux和Windows安装

    这篇文章主要为大家详细介绍了PHP程序员玩转Linux系列文章,Linux和Windows安装nginx教程,具有一定的参考价值,感兴趣...

    玩转 程序员 安装 系列 PHP

    win10怎么安装杜比音效Doby V4.1 win10安装杜

    第四代杜比®家庭影院®技术包含了一整套协同工作的技术,让PC 发出清晰的环绕声同时第四代杜比家庭影院技术...

    win10杜比音效

    纯CSS实现iOS风格打开关闭选择框功能

    这篇文章主要介绍了纯CSS实现iOS风格打开关闭选择框,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作...

    css ios c

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的办法

    Win7如何给C盘扩容 Win7系统电脑C盘扩容的

    Win7给电脑C盘扩容的办法大家知道吗?当系统分区C盘空间不足时,就需要给它扩容了,如果不管,C盘没有足够的空间...

    Win7 C盘 扩容

    百度推广竞品词的投放策略

    SEM是基于关键词搜索的营销活动。作为推广人员,我们所做的工作,就是打理成千上万的关键词,关注它们的质量度...

    百度推广 竞品词

    Visual Studio Code(vscode) git的使用教程

    这篇文章主要介绍了详解Visual Studio Code(vscode) git的使用,小编觉得挺不错的,现在分享给大家,也给大家做个参考。...

    教程 Studio Visual Code git

    七牛云储存创始人分享七牛的创立故事与

    这篇文章主要介绍了七牛云储存创始人分享七牛的创立故事与对Go语言的应用,七牛选用Go语言这门新兴的编程语言进行...

    七牛 Go语言

    Win10预览版Mobile 10547即将发布 9月19日上午

    微软副总裁Gabriel Aul的Twitter透露了 Win10 Mobile预览版10536即将发布,他表示该版本已进入内部慢速版阶段,发布时间目...

    Win10 预览版

    HTML标签meta总结,HTML5 head meta 属性整理

    移动前端开发中添加一些webkit专属的HTML5头部标签,帮助浏览器更好解析HTML代码,更好地将移动web前端页面表现出来...

    移动端html5模拟长按事件的实现方法

    这篇文章主要介绍了移动端html5模拟长按事件的实现方法的相关资料,小编觉得挺不错的,现在分享给大家,也给大家...

    移动端 html5 长按

    HTML常用meta大全(推荐)

    这篇文章主要介绍了HTML常用meta大全(推荐),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参...

    cdr怎么把图片转换成位图? cdr图片转换为位图的教程

    cdr怎么把图片转换成位图? cdr图片转换为

    cdr怎么把图片转换成位图?cdr中插入的图片想要转换成位图,该怎么转换呢?下面我们就来看看cdr图片转换为位图的...

    cdr 图片 位图

    win10系统怎么录屏?win10系统自带录屏详细教程

    win10系统怎么录屏?win10系统自带录屏详细

    当我们是使用win10系统的时候,想要录制电脑上的画面,这时候有人会想到下个第三方软件,其实可以用电脑上的自带...

    win10 系统自带录屏 详细教程

    + 更多教程 +
    ASP编程JSP编程PHP编程.NET编程python编程