Implementing WGAN in PyTorch for Anime Avatar Generation


How WGAN differs from a standard GAN

  • Remove the sigmoid from the last layer of the Discriminator, so it outputs an unbounded score (a "critic") rather than a probability
  • Drop the log from the loss: the generator and discriminator losses are plain means of the critic's scores (contrasted in the sketch after this list)
  • After every update, clip the Discriminator's weights to a fixed small range to enforce the Lipschitz continuity constraint
  • Do not use momentum-based optimizers (momentum SGD, Adam); RMSProp is recommended instead
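
The first two points show up directly in the loss computation. Below is a minimal, self-contained sketch of the contrast on toy tensors; the linear critic and the shapes are illustrative stand-ins, not part of the training script further down:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy critic and batches, just to make the loss contrast concrete (shapes are illustrative)
D = nn.Linear(8, 1)        # stands in for a conv critic; note: no Sigmoid at the end
real = torch.randn(4, 8)   # a batch of "real" samples
fake = torch.randn(4, 8)   # a batch of generated samples

# Standard GAN: sigmoid + log, i.e. binary cross-entropy on the discriminator's logits
gan_d_loss = F.binary_cross_entropy_with_logits(D(real), torch.ones(4, 1)) + \
             F.binary_cross_entropy_with_logits(D(fake), torch.zeros(4, 1))

# WGAN: no sigmoid, no log -- the critic loss is a plain difference of mean scores
wgan_d_loss = D(fake).mean() - D(real).mean()
wgan_d_loss.backward()

# After each optimizer step (omitted here), the critic's weights are clipped to
# enforce the Lipschitz constraint (c = 0.01, as in the training script below)
for p in D.parameters():
    p.data.clamp_(-0.01, 0.01)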

Hands-on: a convolutional WGAN for anime avatar generation

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

from anime_face_generator.dataset import ImageDataset

batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'

# Create the output directory for sample images
if not os.path.exists(dir_path):
    os.mkdir(dir_path)


def to_img(x):
    """Map from [-1, 1] back to [0, 1]; needed because the generator ends with tanh."""
    out = 0.5 * (x + 1)
    return out


dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            # The input is an nz-dimensional noise vector, treated as a 1 x 1 x nz feature map
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # Output shape: (512) x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # Output shape: (256) x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # Output shape: (128) x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # Output shape: (64) x 32 x 32
            nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # Outputs lie in [-1, 1], hence Tanh rather than Sigmoid
            # Output shape: 3 x 96 x 96
        )

    def forward(self, x):
        return self.gen(x)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.dis = nn.Sequential(
            nn.Conv2d(3, 64, 5, 3, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # Output shape: (64) x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # Output shape: (128) x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # Output shape: (256) x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Output shape: (512) x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Flatten(),
            # No Sigmoid: the WGAN critic outputs an unbounded score, not a probability
        )

    def forward(self, x):
        return self.dis(x)


def weight_init(m):
    # Weight initialization, important for WGAN; apply recursively with model.apply(weight_init)
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)


def save(model, filename="model.pt", out_dir="out/"):
    if model is not None:
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
        torch.save({'model': model.state_dict()}, out_dir + filename)
    else:
        print("[ERROR]: Please build a model!!!")


import QuickModelBuilder as builder  # the author's progress-bar helper (provides MyTqdm)

if __name__ == '__main__':
    # Constant gradient tensors used to drive backward() in the chosen sign convention;
    # the training loop assumes a CUDA device throughout
    one = torch.FloatTensor([1]).cuda()
    mone = -1 * one

    is_print = True
    # Build the models and initialize the weights of every submodule
    D = Discriminator()
    G = Generator()
    D.apply(weight_init)
    G.apply(weight_init)

    if torch.cuda.is_available():
        D = D.cuda()
        G = G.cuda()

    lr = 2e-4
    d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr)
    g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr)
    d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
    g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)

    fake_img = None

    # ############################ Training loop ############################
    # Note: the WGAN paper updates the critic n_critic (e.g. 5) times per generator
    # update; this script uses a simple 1:1 schedule.
    for epoch in range(num_epoch):
        pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
        for i, img in enumerate(dataloader):
            num_img = img.size(0)
            real_img = img.cuda()  # Move the batch to the GPU
            # Unfreeze D before the critic update
            for param in D.parameters():
                param.requires_grad = True

            # ######## Train the critic (discriminator) ########
            # Two terms: the mean score on real images and the mean score on fakes
            d_optimizer.zero_grad()  # Zero the gradients before backprop
            real_out = D(real_img)
            d_loss_real = real_out.mean(0).view(1)
            d_loss_real.backward(one)  # The step will push D's score on real images down...

            z = torch.randn(num_img, z_dimension).cuda()  # Sample random noise
            z = z.reshape(num_img, z_dimension, 1, 1)
            fake_img = G(z).detach()  # detach: G is not updated here, so no gradient flows into it
            fake_out = D(fake_img)
            d_loss_fake = fake_out.mean(0).view(1)
            d_loss_fake.backward(mone)  # ...and the score on fakes up; only the gap matters

            d_loss = d_loss_fake - d_loss_real  # Wasserstein estimate, for logging only
            d_optimizer.step()

            # After every critic update, clamp the weights' absolute values to c = 0.01
            for param in D.parameters():
                param.data.clamp_(-0.01, 0.01)

            # #################### Train the generator ####################
            # Freeze D: the generator step needs no gradients w.r.t. D's parameters,
            # so skipping them saves computation
            for param in D.parameters():
                param.requires_grad = False

            g_optimizer.zero_grad()
            z = torch.randn(num_img, z_dimension).cuda()
            z = z.reshape(num_img, z_dimension, 1, 1)
            fake_img = G(z)  # Generate a batch of fake images from noise
            output = D(fake_img)  # Score them with the critic
            g_loss = torch.mean(output).view(1)
            g_loss.backward(one)  # G is updated to lower the critic's score on its samples
            g_optimizer.step()

            # Report running losses and critic scores
            pbar.set_right_info(d_loss=d_loss.data.item(),
                                g_loss=g_loss.data.item(),
                                real_scores=real_out.data.mean().item(),
                                fake_scores=fake_out.data.mean().item(),
                                )
            pbar.update()
            try:
                # Overwrites the same per-epoch file on every iteration
                fake_images = to_img(fake_img.detach().cpu())
                save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
            except:
                pass
            if is_print:
                is_print = False
                real_images = to_img(real_img.cpu())
                save_image(real_images, dir_path + '/real_images.png')
        pbar.finish()
        d_scheduler.step()
        g_scheduler.step()
        save(D, "wgan_D.pt")
        save(G, "wgan_G.pt")
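
ImageDataset comes from the author's own anime_face_generator package and is not shown in the article. Below is a minimal stand-in; the folder path and file layout are assumptions, but the transform must resize to 96 x 96 and normalize to [-1, 1] to match the networks above:

import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Minimal stand-in for anime_face_generator.dataset.ImageDataset (assumed layout)."""
    def __init__(self, root='./faces'):  # hypothetical folder of face crops
        self.paths = [os.path.join(root, f) for f in os.listdir(root)
                      if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transforms.Compose([
            transforms.Resize((96, 96)),   # the networks above expect 3 x 96 x 96
            transforms.ToTensor(),         # [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1], matching tanh
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return self.transform(Image.open(self.paths[idx]).convert('RGB'))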
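
Once training has written out/wgan_G.pt, sampling is just a forward pass through the generator. A short sketch, reusing the Generator class and the {'model': state_dict} checkpoint format produced by save() above:

import torch
from torchvision.utils import save_image

G = Generator()
checkpoint = torch.load('out/wgan_G.pt', map_location='cpu')
G.load_state_dict(checkpoint['model'])  # matches the {'model': ...} format of save()
G.eval()

with torch.no_grad():
    z = torch.randn(64, 100, 1, 1)      # 64 noise vectors, z_dimension = 100
    samples = G(z)
save_image(0.5 * (samples + 1), 'samples.png')  # same [-1, 1] -> [0, 1] mapping as to_img()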

This concludes the article on implementing WGAN in PyTorch for anime avatar generation. For more on the topic, please search our earlier articles or browse the related articles below, and we hope you will continue to support us!
