这份文档将 gradioui.ipynb 里的代码拆解开来,用通俗易懂的语言解释每一行代码在做什么。

第一部分:导入库

在编程里,我们不需要从零开始造轮子。我们先“导入”别人写好的强大工具包。

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import functools
import numpy as np
  • import gradio as gr:
    • Gradio: 这是我们要用的核心工具,专门用来快速搭建可视化网页界面 (Web UI)。
    • as gr: 给它起个短名字叫 gr,以后写 gr.Button 就行,不用写 gradio.Button,偷懒用的。
  • import torch:
    • PyTorch: facebook 开发的深度学习框架。它是我们的人工智能“大脑”。所有复杂的计算、矩阵运算都靠它。
  • import torch.nn as nn:
    • nn (Neural Network): PyTorch 里专门用来搭建神经网络的子模块。比如“卷积层”、“激活函数”都在这里面。
  • import torchvision.transforms as transforms:
    • torchvision: PyTorch 的计算机视觉工具包。
    • transforms: 变形金刚?不,是“图片变换工具”。用来把普通的图片 (比如 JPG) 变成 AI 能看懂的数据格式 (Tensor)。
  • from PIL import Image:
    • PIL (Python Imaging Library): Python 处理图片的基础库。用来打开、保存、调整图片大小。
  • import functools:
    • Python 自带的一个工具箱,主要用于一些高级的函数操作(后面用到 functools.partial 偏函数)。
  • import numpy as np:
    • NumPy: Python 数据科学的基础。专门处理数字矩阵。在这里主要用于把 AI 算出来的结果转回图片数据。

第二部分:绘制图纸 (定义神经网络)

这一大段代码是在定义“生成器” (Generator) 长什么样。你可以把它想象成我们在画一张建筑图纸

2.1 主生成器 ResnetGenerator

class ResnetGenerator(nn.Module):
  • class: 定义一个“类” (Class)。你可以理解为定义一种“模具”或者“蓝图”。
  • (nn.Module): 这个类继承自 PyTorch 的标准神经网络模块。意思是:“我是一个神经网络,我有标准的 forward (前向传播) 功能”。
    def __init__(self, input_nc, output_nc, ngf=64, ...):
        super(ResnetGenerator, self).__init__()
  • __init__: 初始化函数。当你通过这个模具生产一个实物(实例化)时,最先运行这段代码。
  • super(...): 既然继承了爸爸 (nn.Module),初始化时得先喊一声爸爸,让它把基础工作做完。

(中间省略具体的卷积层定义,核心逻辑是:)

  1. 下采样 (Downsampling): 把图片变小,提取特征。就像看图时眯起眼睛,看大轮廓。
  2. 残差块 (Resnet Blocks): 处理特征,转换风格。这是“马变斑马”发生魔法的核心区域。
  3. 上采样 (Upsampling): 把变小后的特征图再放大回去,还原成一张清晰的图片。
    def forward(self, input):
        return self.model(input)
  • forward: 前向传播。这是神经网络“思考”的过程。
  • 当你有数据输入时,它会拿着数据走一遍 self.model 定义的所有层,最后吐出结果。

2.2 积木块 ResnetBlock

class ResnetBlock(nn.Module):
  • 这是上面那个大建筑里的小砖块。
  • ResNet (残差网络) 的核心思想是:把输入的数据直接加到输出上 (x + self.conv_block(x))。
  • 为什么要加? 防止原来的图片信息丢了。比如马变斑马,马的形状、位置不能变,只变条纹。加上原始数据可以保留这些结构信息。

第三部分:组装引擎 (推理类)

前面的代码只是“图纸”,这一部分我们不仅要把机器造出来,还要给它通电、加载经验(加载权重)。

class CycleGANInference:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • self.device: 决定用什么硬件跑。
  • torch.cuda.is_available(): 问电脑:“你有 NVIDIA 显卡 (GPU) 吗?”
    • 如果有,就用 cuda (显卡跑,速度快)。
    • 如果没有,就用 cpu (处理器跑,慢很多)。
        self.netG_h2z = ResnetGenerator(3, 3, n_blocks=9).to(self.device)
        self.netG_a2o = ResnetGenerator(3, 3, n_blocks=9).to(self.device)
  • 这里我们造了两台机器
    • netG_h2z: Horse to Zebra (马转斑马)。
    • netG_a2o: Apple to Orange (苹果转橙子)。
  • ResnetGenerator(3, 3 ...):
    • 输入 3 通道 (红绿蓝 RGB)。
    • 输出 3 通道 (也是 RGB 图片)。
  • .to(self.device): 把这个神经网络搬到显卡(或CPU)内存里去。
        self.load_weights(self.netG_h2z, "model/horse2zebra.pth")
  • load_weights: 调用下面写的加载函数。
  • model/horse2zebra.pth: 这是“经验书” (权重文件)。刚造出来的神经网络是白痴,必须读取训练好的权重文件,它才知道怎么把马变成斑马。
        self.transform = transforms.Compose([
            transforms.Resize((256, 256), Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(...)
        ])
  • 预处理流水线:AI 比较挑食,它不吃 JPG 图片,它吃 Tensor (张量)。
    1. Resize: 强制把图片缩放到 256x256 像素 (因为这个模型是按这个尺寸训练的)。
    2. ToTensor: 把图片像素值 (0-255) 变成小数 (0.0-1.0) 并转成 PyTorch 的数据格式。
    3. Normalize: 数学归一化。把数据范围进一步压缩到 (-1, 1) 之间,模型更喜欢处理这种数据。

3.1 预测函数 predict

    def predict(self, input_img, mode):
  • 这是给 Gradio 界面调用的核心功能函数。
  • input_img: 用户上传的图片对象。
  • mode: 用户选的模式 (“马变斑马” 还是 “苹果变橙子”)。
        img_tensor = self.transform(input_img).unsqueeze(0).to(self.device)
  • self.transform: 之前的“流水线”派上用场了,把图片变成数据。
  • unsqueeze(0): 升维。
    • 一张图片的数据形状是 (3, 256, 256) (通道, 高, 宽)。
    • AI 模型通常一次处理一批图片 (Batch),所以需要变成 (1, 3, 256, 256)。这个 1 就是 Batch Size。
        with torch.no_grad():
            output_tensor = model(img_tensor)
  • torch.no_grad(): 告诉 PyTorch:“我现在是在以前学到的知识考试,不是在学习”。这样它就不会计算梯度,省内存,跑得快。
  • model(img_tensor): 真正开始计算。进能不能变出斑马就看这一下。
        output_img = output_tensor.squeeze(0).cpu().float().numpy()
        output_img = (output_img + 1) / 2.0 * 255.0
  • 后处理:把 AI 吐出来的结果(Tensor)变回人类能看的图片。
    • squeeze(0): 把刚才升维加的那个 1 去掉。
    • cpu(): 把数据从显卡拉回内存。
    • numpy(): 转成通用数组格式。
    • (output_img + 1) / 2.0 * 255.0: 反归一化。之前把数据压到了 (-1, 1),现在要还原回 (0, 255) 的颜色值。

第四部分:装修店面 (Gradio 界面)

最后这部分代码是用来画网页界面的。

with gr.Blocks(css=".fixed-height { height: 350px; }") as demo:
  • gr.Blocks: 创建一个积木板,我们可以自由地把按钮、图片框往上堆。
  • css: 写了一点点网页样式,强制图片框的高度为 350px,防止图片忽大忽小界面乱跳。
    mode_selector = gr.Radio(...)
  • gr.Radio: 单选按钮。让用户二选一。
    run_btn = gr.Button("🚀 开始转换", variant="primary")
  • gr.Button: 一个大按钮。
    run_btn.click(
        fn=engine.predict,
        inputs=[input_view, mode_selector],
        outputs=output_view
    )
  • 核心绑定逻辑
    • 当用户点击 run_btn 时…
    • fn (Function): 去执行 engine.predict 这个函数。
    • inputs: 把界面上的 input_view (图片) 和 mode_selector (选项) 里的内容传给函数。
    • outputs: 函数运行完的结果,显示在 output_view 组件里。