这份文档将 gradioui.ipynb 里的代码拆解开来,用通俗易懂的语言解释每一行代码在做什么。
第一部分:导入库
在编程里,我们不需要从零开始造轮子。我们先“导入”别人写好的强大工具包。
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import functools
import numpy as npimport gradio as gr:- Gradio: 这是我们要用的核心工具,专门用来快速搭建可视化网页界面 (Web UI)。
as gr: 给它起个短名字叫gr,以后写gr.Button就行,不用写gradio.Button,偷懒用的。
import torch:- PyTorch: facebook 开发的深度学习框架。它是我们的人工智能“大脑”。所有复杂的计算、矩阵运算都靠它。
import torch.nn as nn:nn(Neural Network): PyTorch 里专门用来搭建神经网络的子模块。比如“卷积层”、“激活函数”都在这里面。
import torchvision.transforms as transforms:torchvision: PyTorch 的计算机视觉工具包。transforms: 变形金刚?不,是“图片变换工具”。用来把普通的图片 (比如 JPG) 变成 AI 能看懂的数据格式 (Tensor)。
from PIL import Image:- PIL (Python Imaging Library): Python 处理图片的基础库。用来打开、保存、调整图片大小。
import functools:- Python 自带的一个工具箱,主要用于一些高级的函数操作(后面用到
functools.partial偏函数)。
- Python 自带的一个工具箱,主要用于一些高级的函数操作(后面用到
import numpy as np:- NumPy: Python 数据科学的基础。专门处理数字矩阵。在这里主要用于把 AI 算出来的结果转回图片数据。
第二部分:绘制图纸 (定义神经网络)
这一大段代码是在定义“生成器” (Generator) 长什么样。你可以把它想象成我们在画一张建筑图纸。
2.1 主生成器 ResnetGenerator
class ResnetGenerator(nn.Module):class: 定义一个“类” (Class)。你可以理解为定义一种“模具”或者“蓝图”。(nn.Module): 这个类继承自 PyTorch 的标准神经网络模块。意思是:“我是一个神经网络,我有标准的 forward (前向传播) 功能”。
def __init__(self, input_nc, output_nc, ngf=64, ...):
super(ResnetGenerator, self).__init__()__init__: 初始化函数。当你通过这个模具生产一个实物(实例化)时,最先运行这段代码。super(...): 既然继承了爸爸 (nn.Module),初始化时得先喊一声爸爸,让它把基础工作做完。
(中间省略具体的卷积层定义,核心逻辑是:)
- 下采样 (Downsampling): 把图片变小,提取特征。就像看图时眯起眼睛,看大轮廓。
- 残差块 (Resnet Blocks): 处理特征,转换风格。这是“马变斑马”发生魔法的核心区域。
- 上采样 (Upsampling): 把变小后的特征图再放大回去,还原成一张清晰的图片。
def forward(self, input):
return self.model(input)forward: 前向传播。这是神经网络“思考”的过程。- 当你有数据输入时,它会拿着数据走一遍
self.model定义的所有层,最后吐出结果。
2.2 积木块 ResnetBlock
class ResnetBlock(nn.Module):- 这是上面那个大建筑里的小砖块。
- ResNet (残差网络) 的核心思想是:把输入的数据直接加到输出上 (
x + self.conv_block(x))。 - 为什么要加? 防止原来的图片信息丢了。比如马变斑马,马的形状、位置不能变,只变条纹。加上原始数据可以保留这些结构信息。
第三部分:组装引擎 (推理类)
前面的代码只是“图纸”,这一部分我们不仅要把机器造出来,还要给它通电、加载经验(加载权重)。
class CycleGANInference:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.device: 决定用什么硬件跑。torch.cuda.is_available(): 问电脑:“你有 NVIDIA 显卡 (GPU) 吗?”- 如果有,就用
cuda(显卡跑,速度快)。 - 如果没有,就用
cpu(处理器跑,慢很多)。
- 如果有,就用
self.netG_h2z = ResnetGenerator(3, 3, n_blocks=9).to(self.device)
self.netG_a2o = ResnetGenerator(3, 3, n_blocks=9).to(self.device)- 这里我们造了两台机器!
netG_h2z: Horse to Zebra (马转斑马)。netG_a2o: Apple to Orange (苹果转橙子)。
ResnetGenerator(3, 3 ...):- 输入 3 通道 (红绿蓝 RGB)。
- 输出 3 通道 (也是 RGB 图片)。
.to(self.device): 把这个神经网络搬到显卡(或CPU)内存里去。
self.load_weights(self.netG_h2z, "model/horse2zebra.pth")load_weights: 调用下面写的加载函数。model/horse2zebra.pth: 这是“经验书” (权重文件)。刚造出来的神经网络是白痴,必须读取训练好的权重文件,它才知道怎么把马变成斑马。
self.transform = transforms.Compose([
transforms.Resize((256, 256), Image.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(...)
])- 预处理流水线:AI 比较挑食,它不吃 JPG 图片,它吃 Tensor (张量)。
- Resize: 强制把图片缩放到 256x256 像素 (因为这个模型是按这个尺寸训练的)。
- ToTensor: 把图片像素值 (0-255) 变成小数 (0.0-1.0) 并转成 PyTorch 的数据格式。
- Normalize: 数学归一化。把数据范围进一步压缩到 (-1, 1) 之间,模型更喜欢处理这种数据。
3.1 预测函数 predict
def predict(self, input_img, mode):- 这是给 Gradio 界面调用的核心功能函数。
input_img: 用户上传的图片对象。mode: 用户选的模式 (“马变斑马” 还是 “苹果变橙子”)。
img_tensor = self.transform(input_img).unsqueeze(0).to(self.device)self.transform: 之前的“流水线”派上用场了,把图片变成数据。unsqueeze(0): 升维。- 一张图片的数据形状是
(3, 256, 256)(通道, 高, 宽)。 - AI 模型通常一次处理一批图片 (Batch),所以需要变成
(1, 3, 256, 256)。这个1就是 Batch Size。
- 一张图片的数据形状是
with torch.no_grad():
output_tensor = model(img_tensor)torch.no_grad(): 告诉 PyTorch:“我现在是在以前学到的知识考试,不是在学习”。这样它就不会计算梯度,省内存,跑得快。model(img_tensor): 真正开始计算。进能不能变出斑马就看这一下。
output_img = output_tensor.squeeze(0).cpu().float().numpy()
output_img = (output_img + 1) / 2.0 * 255.0- 后处理:把 AI 吐出来的结果(Tensor)变回人类能看的图片。
squeeze(0): 把刚才升维加的那个1去掉。cpu(): 把数据从显卡拉回内存。numpy(): 转成通用数组格式。(output_img + 1) / 2.0 * 255.0: 反归一化。之前把数据压到了 (-1, 1),现在要还原回 (0, 255) 的颜色值。
第四部分:装修店面 (Gradio 界面)
最后这部分代码是用来画网页界面的。
with gr.Blocks(css=".fixed-height { height: 350px; }") as demo:gr.Blocks: 创建一个积木板,我们可以自由地把按钮、图片框往上堆。css: 写了一点点网页样式,强制图片框的高度为 350px,防止图片忽大忽小界面乱跳。
mode_selector = gr.Radio(...)gr.Radio: 单选按钮。让用户二选一。
run_btn = gr.Button("🚀 开始转换", variant="primary")gr.Button: 一个大按钮。
run_btn.click(
fn=engine.predict,
inputs=[input_view, mode_selector],
outputs=output_view
)- 核心绑定逻辑:
- 当用户点击
run_btn时… fn(Function): 去执行engine.predict这个函数。inputs: 把界面上的input_view(图片) 和mode_selector(选项) 里的内容传给函数。outputs: 函数运行完的结果,显示在output_view组件里。
- 当用户点击
