Making a Class-Conditioned Diffusion Model
In this notebook we're going to illustrate one way to add conditioning information to a diffusion model. Specifically, following on from the "from-scratch" example in Unit 1, we'll train a class-conditioned diffusion model on MNIST, so that at inference time we can specify which digit we'd like the model to generate.
As mentioned in the introduction to this unit, this is just one of many ways we could add extra conditioning information to a diffusion model, and it was chosen for its relative simplicity. Just like the "from-scratch" notebook in Unit 1, this notebook is mostly for illustrative purposes, and you can safely skip it if you'd like.
Setup and Data Prep
>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> from tqdm.auto import tqdm
>>> device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
>>> print(f"Using device: {device}")Using device: cuda
>>> # Load the dataset
>>> dataset = torchvision.datasets.MNIST(
... root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )
>>> # Feed it into a dataloader (batch size 8 here just for demo)
>>> train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> # View some examples
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
Creating a Class-Conditioned UNet
We'll feed in the class conditioning as follows:

- Create a standard UNet2DModel with some additional input channels
- Map the class label to a learned vector of shape (class_emb_size) via an embedding layer
- Concatenate this information as extra channels onto the UNet's internal input with net_input = torch.cat((x, class_cond), 1)
- Feed this net_input (which has (class_emb_size+1) channels in total) into the UNet to get the final prediction
In this example I've set class_emb_size to 4, but this is completely arbitrary. You could explore setting it to 1 (to see if it still works), to 10 (to match the number of classes), or replace the learned nn.Embedding with a simple one-hot encoding of the class label directly.
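To make that last option concrete, here is a tiny sketch (not part of the implementation below, which uses the learned embedding) of how a one-hot encoding of the labels could stand in for nn.Embedding:

# Illustrative sketch: one-hot class conditioning instead of a learned nn.Embedding (not used below)
labels = torch.tensor([3, 7, 0])                        # an example mini-batch of class labels
class_cond = F.one_hot(labels, num_classes=10).float()  # shape (3, 10), one channel per class
print(class_cond.shape)                                 # torch.Size([3, 10])
# The UNet would then need in_channels=1 + 10, and class_cond would be expanded to
# (bs, 10, 28, 28) before being concatenated with x, just like the learned embedding.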
This is what the implementation looks like:
class ClassConditionedUnet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()

        # The embedding layer will map the class label to a vector of size class_emb_size
        self.class_emb = nn.Embedding(num_classes, class_emb_size)

        # Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
        self.model = UNet2DModel(
            sample_size=28,  # the target image resolution
            in_channels=1 + class_emb_size,  # Additional input channels for class cond.
            out_channels=1,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            block_out_channels=(32, 64, 64),
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",  # a regular ResNet upsampling block
            ),
        )

    # Our forward method now takes the class labels as an additional argument
    def forward(self, x, t, class_labels):
        # Shape of x:
        bs, ch, w, h = x.shape

        # class conditioning in right shape to add as additional input channels
        class_cond = self.class_emb(class_labels)  # Map to embedding dimension
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # x is shape (bs, 1, 28, 28) and class_cond is now (bs, 4, 28, 28)

        # Net input is now x and class cond concatenated together along dimension 1
        net_input = torch.cat((x, class_cond), 1)  # (bs, 5, 28, 28)

        # Feed this to the UNet alongside the timestep and return the prediction
        return self.model(net_input, t).sample  # (bs, 1, 28, 28)

If any of the shapes or transforms are confusing, add some print statements to show the relevant shapes and check that they match your expectations. I've also annotated the shapes of some intermediate variables to hopefully make things clearer.
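Along those lines, here's a minimal sanity-check sketch (not part of the original notebook) that instantiates the model and runs a dummy batch through it to confirm the input and output shapes:

# Quick shape check (illustrative sketch): run a dummy batch through the model defined above
net = ClassConditionedUnet()

x = torch.randn(8, 1, 28, 28)              # dummy "noisy" images
t = torch.zeros(8).long()                  # dummy timesteps
class_labels = torch.randint(0, 10, (8,))  # dummy labels in 0-9

with torch.no_grad():
    pred = net(x, t, class_labels)

print(pred.shape)  # expected: torch.Size([8, 1, 28, 28])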
Training and Sampling
Where previously we'd do something like prediction = unet(x, t), we'll now pass the correct labels as a third argument during training (prediction = unet(x, t, y)), and at inference time we can pass whatever labels we want and, if all goes well, the model should generate matching images. In this case, y holds the labels of the MNIST digits, with values from 0 to 9.
The training loop is very similar to the example from Unit 1. We're now predicting the noise (rather than the denoised image as in Unit 1) to match the objective expected by the default DDPMScheduler, which we use to add noise during training and to generate samples at inference time. Training takes a while; speeding it up could be a fun mini-project, but most of you can probably just skim the code (and indeed this whole notebook) without running it, since we're just illustrating an idea.
# Create a scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

>>> # @markdown Training loop (10 Epochs):
>>> # Redefining the dataloader to set the batch size higher than the demo of 8
>>> train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 10
>>> # Our network
>>> net = ClassConditionedUnet().to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
...     for x, y in tqdm(train_dataloader):
...         # Get some data and prepare the corrupted version
...         x = x.to(device) * 2 - 1  # Data on the GPU (mapped to (-1, 1))
...         y = y.to(device)
...         noise = torch.randn_like(x)
...         timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
...         noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
...         # Get the model prediction
...         pred = net(noisy_x, timesteps, y)  # Note that we pass in the labels y
...         # Calculate the loss
...         loss = loss_fn(pred, noise)  # How close is the output to the noise
...         # Backprop and update the params:
...         opt.zero_grad()
...         loss.backward()
...         opt.step()
...         # Store the loss for later
...         losses.append(loss.item())
...     # Print out the average of the last 100 loss values to get an idea of progress:
...     avg_loss = sum(losses[-100:]) / 100
...     print(f"Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}")
Finished epoch 0. Average of the last 100 loss values: 0.052451
>>> # View the loss curve
>>> plt.plot(losses)
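The raw loss curve is quite noisy. If you'd like a smoother view of training progress, here's a small optional sketch (not from the original notebook; the window size is an arbitrary choice) that plots a running average alongside the raw values:

# Optional: plot a running average of the loss for a smoother view (illustrative sketch)
import numpy as np

window = 100  # assumed window size; adjust to taste
smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")

plt.plot(losses, alpha=0.3, label="raw loss")
plt.plot(range(window - 1, len(losses)), smoothed, label=f"{window}-step average")
plt.legend()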
Once training finishes, we can sample some images, feeding in different labels as our conditioning:
>>> # @markdown Sampling some different digits:
>>> # Prepare random x to start from, plus some desired labels y
>>> x = torch.randn(80, 1, 28, 28).to(device)
>>> y = torch.tensor([[i] * 8 for i in range(10)]).flatten().to(device)
>>> # Sampling loop
>>> for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
...     # Get model pred
...     with torch.no_grad():
...         residual = net(x, t, y)  # Again, note that we pass in our labels y
...     # Update sample with step
...     x = noise_scheduler.step(residual, t, x).prev_sample
>>> # Show the results
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap="Greys")

That's it! We now have some control over what images are produced.
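For example, here's a hedged sketch (not in the original notebook) of how you might generate a batch of just one digit, say eights, by conditioning every sample on the same label:

# Illustrative sketch: condition every sample on the same class to generate a single digit
digit = 8  # example label; pick anything from 0-9
x = torch.randn(16, 1, 28, 28).to(device)                  # start from random noise
y = torch.full((16,), digit, dtype=torch.long).to(device)  # the same label for every sample

for t in noise_scheduler.timesteps:
    with torch.no_grad():
        residual = net(x, t, y)                           # model prediction, conditioned on y
    x = noise_scheduler.step(residual, t, x).prev_sample  # scheduler update step

plt.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=4)[0], cmap="Greys")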
I hope you enjoyed this example. As always, questions are welcome on Discord.
# Exercise (optional): Try this with FashionMNIST. Tweak the learning rate, batch size and number of epochs.
# Can you get some decent-looking fashion images with less training time than the example above?
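As a hedged starting point for that exercise, swapping in FashionMNIST is mostly a matter of changing the dataset class; the model, scheduler and training loop above can stay the same. A minimal sketch, where the hyperparameter values are just placeholders to tune:

# Illustrative sketch for the exercise: swap MNIST for FashionMNIST
dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Re-create the model and optimizer, then re-run the training loop above, e.g.:
# net = ClassConditionedUnet().to(device)
# opt = torch.optim.Adam(net.parameters(), lr=2e-3)  # learning rate is just an example value to tweak
# n_epochs = 5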