一、模型介绍

  1. CycleGAN
    • 循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
    • 作用:在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y
  2. 应用领域
    • 域迁移,通俗理解为图像风格迁移。
  3. 与 Pix2Pix 的区别
    • Pix2Pix 要求训练数据成对,CycleGAN 则不需要,更适用于现实中难以获取成对图像数据的情况。

二、模型结构

  1. 由两个镜像对称的 GAN 网络组成。
    • 以苹果和橘子为例,X 为苹果,Y 为橘子;G 为将苹果生成橘子风格的生成器,F 为将橘子生成苹果风格的生成器,D_XD_Y 为其相应判别器。
  2. 关键部分为循环一致损失(Cycle Consistency Loss),确保从一个域转换再转换回来能回到初始状态。

三、数据集

名称 描述
来源 ImageNet
内容 只使用了其中的苹果橘子部分
预处理 图像被统一缩放为 256×256 像素大小,进行随机裁剪、水平随机翻转和归一化,并转换为 MindRecord 格式
数量 训练集:苹果 996 张、橘子 1020 张;测试集:苹果 266 张、橘子 248 张

四、可视化

使用 matplotlib 模块对训练数据进行可视化。

五、构建生成器

函数 参数 作用 示例
ConvNormReLU input_channel(输入通道数)、out_planes(输出通道数)、kernel_size(卷积核大小)、stride(步长)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、pad_mode(填充模式)、use_relu(是否使用 ReLU)、padding(填充大小)、transpose(是否为转置卷积) 进行卷积、归一化和激活操作 conv_norm_relu = ConvNormReLU(3, 64, 4, 2, 0.2, 'instance', 'CONSTANT', True)
ResidualBlock dim(维度)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式) 构建残差块 residual_block = ResidualBlock(64, 'instance', False)
ResNetGenerator input_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(残差块数量)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式) 构建生成器网络 net_rg = ResNetGenerator(3, 64, 9, 0.2, 'instance', False)

六、构建判别器

函数 参数 作用 示例
Discriminator input_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(卷积层数)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式) 构建判别器网络 net_d = Discriminator(3, 64, 3, 0.2, 'instance')

七、优化器和损失函数

  1. 生成器和判别器分别使用单独的 Adam 优化器。
  2. 生成器的目标损失函数包括对抗损失和循环一致损失。
    • 对抗损失:L_{GAN}(G,D_Y,X,Y) = E_{y∼p_{data}(y)}[logD_Y(y)] + E_{x∼p_{data}(x)}[log(1 - D_Y(G(x)))]
    • 循环一致损失:L_{cyc}(G,F) = E_{x∼p_{data}(x)}[∥F(G(x)) - x∥_1] + E_{y∼p_{data}(y)}[∥G(F(y)) - y∥_1]

八、前向计算

函数 参数 作用 示例
generator img_a(源域图像)、img_b(目标域图像) 进行图像生成和转换 fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
generator_forward img_a(源域图像)、img_b(目标域图像) 计算生成器的损失 fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b = generator_forward(img_a, img_b)
discriminator_forward img_a(源域真实图像)、img_b(目标域真实图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像) 计算判别器的损失 loss_d = discriminator_forward(img_a, img_b, fake_a, fake_b)

九、计算梯度和反向传播

函数 参数 作用 示例
value_and_grad 待求梯度的函数、梯度相对于哪些输入、待优化的参数 计算函数的梯度 grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
train_step_g img_a(源域图像)、img_b(目标域图像) 计算生成器的梯度并反向传播更新参数 fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = train_step_g(img_a, img_b)
train_step_d img_a(源域图像)、img_b(目标域图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像) 计算判别器的梯度并反向传播更新参数 loss_d = train_step_d(img_a, img_b, fake_a, fake_b)

十、模型训练

  1. 分为训练判别器和生成器两部分。
    • 判别器训练目的是提高判别图像真伪的概率。
    • 生成器训练目的是产生更好的虚假图像。
  2. 训练过程中打印损失等信息,并定期保存模型参数。

十一、模型推理

函数 参数 作用 示例
load_ckpt net(网络)、ckpt_dir(模型参数文件路径) 加载模型参数 load_ckpt(net_rg_a, g_a_ckpt)
eval_data dir_path(图像目录路径)、net(网络)、a(偏移量) 对图像进行推理并展示结果 eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
Logo

昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链

更多推荐