昇思25天学习打卡营第11天 |昇思MindSpore CycleGAN 图像风格迁移学习
CycleGAN循环对抗生成网络,来自论文。作用:在没有配对示例的情况下学习将图像从源域X转换到目标域Y。应用领域域迁移,通俗理解为图像风格迁移。与 Pix2Pix 的区别Pix2Pix 要求训练数据成对,CycleGAN 则不需要,更适用于现实中难以获取成对图像数据的情况。
·
一、模型介绍
- CycleGAN
- 循环对抗生成网络,来自论文
Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks。 - 作用:在没有配对示例的情况下学习将图像从源域
X转换到目标域Y。
- 循环对抗生成网络,来自论文
- 应用领域
- 域迁移,通俗理解为图像风格迁移。
- 与 Pix2Pix 的区别
- Pix2Pix 要求训练数据成对,CycleGAN 则不需要,更适用于现实中难以获取成对图像数据的情况。
二、模型结构
- 由两个镜像对称的 GAN 网络组成。
- 以苹果和橘子为例,
X为苹果,Y为橘子;G为将苹果生成橘子风格的生成器,F为将橘子生成苹果风格的生成器,D_X和D_Y为其相应判别器。
- 以苹果和橘子为例,
- 关键部分为循环一致损失(Cycle Consistency Loss),确保从一个域转换再转换回来能回到初始状态。
三、数据集
| 名称 | 描述 |
|---|---|
| 来源 | ImageNet |
| 内容 | 只使用了其中的苹果橘子部分 |
| 预处理 | 图像被统一缩放为 256×256 像素大小,进行随机裁剪、水平随机翻转和归一化,并转换为 MindRecord 格式 |
| 数量 | 训练集:苹果 996 张、橘子 1020 张;测试集:苹果 266 张、橘子 248 张 |
四、可视化
使用 matplotlib 模块对训练数据进行可视化。
五、构建生成器
| 函数 | 参数 | 作用 | 示例 |
|---|---|---|---|
ConvNormReLU |
input_channel(输入通道数)、out_planes(输出通道数)、kernel_size(卷积核大小)、stride(步长)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、pad_mode(填充模式)、use_relu(是否使用 ReLU)、padding(填充大小)、transpose(是否为转置卷积) |
进行卷积、归一化和激活操作 | conv_norm_relu = ConvNormReLU(3, 64, 4, 2, 0.2, 'instance', 'CONSTANT', True) |
ResidualBlock |
dim(维度)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式) |
构建残差块 | residual_block = ResidualBlock(64, 'instance', False) |
ResNetGenerator |
input_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(残差块数量)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式)、dropout(是否使用 dropout)、pad_mode(填充模式) |
构建生成器网络 | net_rg = ResNetGenerator(3, 64, 9, 0.2, 'instance', False) |
六、构建判别器
| 函数 | 参数 | 作用 | 示例 |
|---|---|---|---|
Discriminator |
input_channel(输入通道数)、output_channel(初始输出通道数)、n_layers(卷积层数)、alpha(LeakyReLU 的斜率)、norm_mode(归一化模式) |
构建判别器网络 | net_d = Discriminator(3, 64, 3, 0.2, 'instance') |
七、优化器和损失函数
- 生成器和判别器分别使用单独的
Adam优化器。 - 生成器的目标损失函数包括对抗损失和循环一致损失。
- 对抗损失:
L_{GAN}(G,D_Y,X,Y) = E_{y∼p_{data}(y)}[logD_Y(y)] + E_{x∼p_{data}(x)}[log(1 - D_Y(G(x)))] - 循环一致损失:
L_{cyc}(G,F) = E_{x∼p_{data}(x)}[∥F(G(x)) - x∥_1] + E_{y∼p_{data}(y)}[∥G(F(y)) - y∥_1]
- 对抗损失:
八、前向计算
| 函数 | 参数 | 作用 | 示例 |
|---|---|---|---|
generator |
img_a(源域图像)、img_b(目标域图像) |
进行图像生成和转换 | fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b) |
generator_forward |
img_a(源域图像)、img_b(目标域图像) |
计算生成器的损失 | fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b = generator_forward(img_a, img_b) |
discriminator_forward |
img_a(源域真实图像)、img_b(目标域真实图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像) |
计算判别器的损失 | loss_d = discriminator_forward(img_a, img_b, fake_a, fake_b) |
九、计算梯度和反向传播
| 函数 | 参数 | 作用 | 示例 |
|---|---|---|---|
value_and_grad |
待求梯度的函数、梯度相对于哪些输入、待优化的参数 | 计算函数的梯度 | grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params()) |
train_step_g |
img_a(源域图像)、img_b(目标域图像) |
计算生成器的梯度并反向传播更新参数 | fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = train_step_g(img_a, img_b) |
train_step_d |
img_a(源域图像)、img_b(目标域图像)、fake_a(源域生成的假图像)、fake_b(目标域生成的假图像) |
计算判别器的梯度并反向传播更新参数 | loss_d = train_step_d(img_a, img_b, fake_a, fake_b) |
十、模型训练
- 分为训练判别器和生成器两部分。
- 判别器训练目的是提高判别图像真伪的概率。
- 生成器训练目的是产生更好的虚假图像。
- 训练过程中打印损失等信息,并定期保存模型参数。
十一、模型推理
| 函数 | 参数 | 作用 | 示例 |
|---|---|---|---|
load_ckpt |
net(网络)、ckpt_dir(模型参数文件路径) |
加载模型参数 | load_ckpt(net_rg_a, g_a_ckpt) |
eval_data |
dir_path(图像目录路径)、net(网络)、a(偏移量) |
对图像进行推理并展示结果 | eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0) |
昇腾计算产业是基于昇腾系列(HUAWEI Ascend)处理器和基础软件构建的全栈 AI计算基础设施、行业应用及服务,https://devpress.csdn.net/organization/setting/general/146749包括昇腾系列处理器、系列硬件、CANN、AI计算框架、应用使能、开发工具链、管理运维工具、行业应用及服务等全产业链
更多推荐

所有评论(0)