Diffusion Models from Scratch
Sometimes it helps to look at the simplest possible version of something to better understand how it works. That's what we'll try in this notebook, beginning with a "toy" diffusion model to see how the different pieces work, and then examining how they differ from a more complex implementation.
We will look at:
- The corruption process (adding noise to data)
- What a UNet is, and how to implement an extremely minimal one from scratch
- Diffusion model training
- Sampling theory
Then we'll compare our version with the diffusers DDPM implementation, exploring:
- Improvements over our mini UNet
- The DDPM noise schedule
- Differences in training objectives
- Timestep conditioning
- Sampling approaches
This notebook is fairly in-depth, and if you're not interested in a from-scratch deep dive, feel free to skip it!
It's also worth noting that most of the code here is for illustrative purposes, and I wouldn't recommend adopting any of it directly for your own work (unless you're just trying to improve on the examples shown here for learning purposes).
Setup and imports:
>>> %pip install -q diffusers
>>> import torch
>>> import torchvision
>>> from torch import nn
>>> from torch.nn import functional as F
>>> from torch.utils.data import DataLoader
>>> from diffusers import DDPMScheduler, UNet2DModel
>>> from matplotlib import pyplot as plt
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> print(f"Using device: {device}")Using device: cuda
Data
Here we're going to test things with a very small dataset: MNIST. If you'd like to give the model a slightly harder challenge without changing anything else, torchvision.datasets.FashionMNIST should work as a drop-in replacement.
>>> dataset = torchvision.datasets.MNIST(
... root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
... )Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
>>> x, y = next(iter(train_dataloader))
>>> print("Input shape:", x.shape)
>>> print("Labels:", y)
>>> plt.imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")Input shape: torch.Size([8, 1, 28, 28]) Labels: tensor([1, 9, 7, 3, 5, 2, 1, 4])
Each image is a 28×28 pixel greyscale drawing of a handwritten digit, with values ranging from 0 to 1.
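If you'd like to try the slightly harder challenge mentioned above, the swap is a single line. Here's a sketch (the root folder name is just an example, and this isn't run in this notebook):
# Hypothetical drop-in replacement for the MNIST dataset loaded above
dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor()
)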
The Corruption Process
Pretend you haven't read any diffusion model papers, but you know the process involves adding noise. How would you do it?
We'd probably like an easy way to control the amount of corruption. So what if we take in a parameter amount for the amount of noise to add, and then do:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
If amount = 0, we get back the input without any changes. If amount gets up to 1, we get back noise with no trace of the input x. By mixing the input with noise this way, we keep the output in the same range (0 to 1).
We can implement this fairly easily (just watch the shapes so you don't get burnt by broadcasting rules):
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount
And check that it works as expected by visualising the results:
>>> # Plotting the input data
>>> fig, axs = plt.subplots(2, 1, figsize=(12, 5))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap="Greys")
>>> # Adding noise
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Plotting the noised version
>>> axs[1].set_title("Corrupted data (-- amount increases -->)")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap="Greys")
As the noise amount approaches one, our data starts to look like pure random noise. But for most noise amounts, you can still guess the digit fairly well. Do you think this is optimal?
The Model
We'd like a model that takes in a 28px noisy image and outputs a prediction of the same shape. A popular choice here is an architecture called a UNet. Originally invented for segmentation tasks in medical imaging, a UNet consists of a "constricting path" through which the data is compressed down and an "expanding path" through which it expands back up to the original dimension (similar to an autoencoder), but it also features skip connections that allow information and gradients to flow across at different levels.
Some UNets feature complex blocks at each stage, but for this toy demo we'll build a minimal example that takes in a one-channel image and passes it through three convolutional layers on the downsampling path (the down_layers in the diagram and the code) and three on the upsampling path, with skip connections between the down and up layers. We'll use max pooling for downsampling and nn.Upsample for upscaling rather than relying on learnable layers like more sophisticated UNets do. Here is the rough architecture, showing the number of channels in the output of each layer:
This is what that looks like in code:
class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
            ]
        )
        self.up_layers = torch.nn.ModuleList(
            [
                nn.Conv2d(64, 64, kernel_size=5, padding=2),
                nn.Conv2d(64, 32, kernel_size=5, padding=2),
                nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
            ]
        )
        self.act = nn.SiLU() # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x)) # Through the layer and the activation function
            if i < 2: # For all but the third (final) down layer:
                h.append(x) # Storing output for skip connection
                x = self.downscale(x) # Downscale ready for the next layer
        for i, l in enumerate(self.up_layers):
            if i > 0: # For all except the first up layer
                x = self.upscale(x) # Upscale
                x += h.pop() # Fetching stored output (skip connection)
            x = self.act(l(x)) # Through the layer and the activation function
        return x
We can verify that the output shape is the same as the input, as we expect:
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
This network has over 300,000 parameters:
sum([p.numel() for p in net.parameters()])
Feel free to experiment with changing the number of channels in each layer or swapping in a different architecture if you want.
Training the network
So what should the model do, exactly? Again, there are various takes on this, but for this demo let's pick a simple framing: given a corrupted input noisy_x, the model should output its best guess for what the original x looks like. We will compare this against the actual value via the mean squared error.
We can now have a go at training the network.
- Get a batch of data
- Corrupt it by random amounts
- Feed it through the model
- Compare the model predictions with the clean images to calculate our loss
- Update the model's parameters accordingly.
Feel free to modify this and see if you can get it working better!
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = BasicUNet()
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x)
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
... # Print out the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # View the loss curve
>>> plt.plot(losses)
>>> plt.ylim(0, 0.1)Finished epoch 0. Average loss for this epoch: 0.026736 Finished epoch 1. Average loss for this epoch: 0.020692 Finished epoch 2. Average loss for this epoch: 0.018887
We can try to see what the model predictions look like by grabbing a batch of data, corrupting it by different amounts and then viewing the model's outputs.
>>> # @markdown Visualizing model predictions on noisy inputs:
>>> # Fetch some data
>>> x, y = next(iter(train_dataloader))
>>> x = x[:8] # Only using the first 8 for easy plotting
>>> # Corrupt with a range of amounts
>>> amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
>>> noised_x = corrupt(x, amount)
>>> # Get the model predictions
>>> with torch.no_grad():
... preds = net(noised_x.to(device)).detach().cpu()
>>> # Plot
>>> fig, axs = plt.subplots(3, 1, figsize=(12, 7))
>>> axs[0].set_title("Input data")
>>> axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Corrupted data")
>>> axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap="Greys")
>>> axs[2].set_title("Network Predictions")
>>> axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap="Greys")
You can see that for the lower amounts the predictions are pretty good! But as the noise level gets very high there is less for the model to work with, and by the time amount reaches 1 it outputs a blurry image close to the mean of the dataset in an attempt to hedge its bets on what the output might look like...
Sampling
If our predictions at high noise levels aren't very good, how do we generate images?
Well, what if we start from random noise, look at the model's prediction, but then only move a small amount towards that prediction (say, 20% of the way there)? Now we have a very noisy image in which there is perhaps a hint of structure, which we can feed into the model to get a new prediction. The hope is that this new prediction is slightly better than the first one (since our starting point is slightly less noisy), and so we can take another small step with this new, better prediction.
Repeat a few times and (if all goes well) we get an image out! Here is that process illustrated over just 5 steps, visualising the model input (left) and the predicted denoised images (right) at each stage. Note that even though the model predicts the denoised image as early as step 1, we only move x part of the way there. Over a few steps the structures appear and are refined, until we get our final outputs.
>>> # @markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
>>> n_steps = 5
>>> x = torch.rand(8, 1, 28, 28).to(device) # Start from random
>>> step_history = [x.detach().cpu()]
>>> pred_output_history = []
>>> for i in range(n_steps):
... with torch.no_grad(): # No need to track gradients during inference
... pred = net(x) # Predict the denoised x0
... pred_output_history.append(pred.detach().cpu()) # Store model output for plotting
... mix_factor = 1 / (n_steps - i) # How much we move towards the prediction
... x = x * (1 - mix_factor) + pred * mix_factor # Move part of the way there
... step_history.append(x.detach().cpu()) # Store step for plotting
>>> fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
>>> axs[0, 0].set_title("x (model input)")
>>> axs[0, 1].set_title("model prediction")
>>> for i in range(n_steps):
... axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap="Greys")
... axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap="Greys")
We can split the process up into more steps and hope for better images that way:
>>> # @markdown Showing more results, using 40 sampling steps
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x)
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> fig, ax = plt.subplots(1, 1, figsize=(12, 12))
>>> ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
Not amazing, but there are some recognizable digits there! You can experiment with training for longer (say, 10 or 20 epochs) and tweaking the model configuration, learning rate, optimizer and so on. Also, don't forget that FashionMNIST is a one-line swap if you'd like to try a slightly harder dataset.
Comparison with DDPM
In this section we'll take a look at how our toy implementation differs from the approach used in the other notebook (Introduction to Diffusers), which is based on the DDPM paper.
We'll see that:
- The diffusers UNet2DModel is a bit more advanced than our BasicUNet
- The corruption process is handled differently
- The training objective is different, involving predicting the noise rather than the denoised image
- The model is conditioned on the amount of noise present via timestep conditioning, where t is passed as an additional argument to the forward method.
- There are a number of different sampling strategies available, which should work better than our simplistic version above.
There have been a number of improvements suggested since the DDPM paper came out, but this example will hopefully be instructive as to the different design decisions available. Once you've read through this, you may enjoy diving into the paper 'Elucidating the Design Space of Diffusion-Based Generative Models', which explores all of these components in detail and makes new recommendations for how to get the best performance.
If all of this is too technical or intimidating, don't worry! Feel free to skip the rest of this notebook or save it for a rainy day.
UNet
The UNet2DModel in diffusers has a number of improvements over our basic UNet above:
- GroupNorm applies group normalisation to the inputs of each block
- Dropout layers for smoother training
- Multiple resnet layers per block (if layers_per_block isn't set to 1)
- Attention (usually used only at lower resolution blocks)
- Conditioning on the timestep.
- Downsampling and upsampling blocks with learnable parameters
Let's create and inspect a UNet2DModel:
>>> model = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... )
>>> print(model)UNet2DModel(
(conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=32, out_features=128, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=128, out_features=128, bias=True)
)
(down_blocks): ModuleList(
(0): DownBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(1): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(2): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(up_blocks): ModuleList(
(0): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(1): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(2): UpBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
)
)
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
As you can see, a little more going on! It also has significantly more parameters than our BasicUNet:
sum([p.numel() for p in model.parameters()]) # 1.7M vs the ~309k parameters of the BasicUNet
We can replicate the training shown above using this model in place of our original one. We need to pass in both x and the timestep to the model (here I always pass t=0 to show that it works without this timestep conditioning and to keep the sampling code easy, but you can also try feeding in (amount*1000) to get a timestep equivalent from the corruption amount). The lines that have changed are marked with #<<< if you want to inspect the code.
>>> # @markdown Trying UNet2DModel instead of BasicUNet:
>>> # Dataloader (you can mess with batch size)
>>> batch_size = 128
>>> train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
>>> # How many runs through the data should we do?
>>> n_epochs = 3
>>> # Create the network
>>> net = UNet2DModel(
... sample_size=28, # the target image resolution
... in_channels=1, # the number of input channels, 3 for RGB images
... out_channels=1, # the number of output channels
... layers_per_block=2, # how many ResNet layers to use per UNet block
... block_out_channels=(32, 64, 64), # Roughly matching our basic unet example
... down_block_types=(
... "DownBlock2D", # a regular ResNet downsampling block
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
... "AttnDownBlock2D",
... ),
... up_block_types=(
... "AttnUpBlock2D",
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
... "UpBlock2D", # a regular ResNet upsampling block
... ),
... ) # <<<
>>> net.to(device)
>>> # Our loss function
>>> loss_fn = nn.MSELoss()
>>> # The optimizer
>>> opt = torch.optim.Adam(net.parameters(), lr=1e-3)
>>> # Keeping a record of the losses for later viewing
>>> losses = []
>>> # The training loop
>>> for epoch in range(n_epochs):
... for x, y in train_dataloader:
... # Get some data and prepare the corrupted version
... x = x.to(device) # Data on the GPU
... noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
... noisy_x = corrupt(x, noise_amount) # Create our noisy x
... # Get the model prediction
... pred = net(noisy_x, 0).sample # <<< Using timestep 0 always, adding .sample
... # Calculate the loss
... loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?
... # Backprop and update the params:
... opt.zero_grad()
... loss.backward()
... opt.step()
... # Store the loss for later
... losses.append(loss.item())
... # Print out the average of the loss values for this epoch:
... avg_loss = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
... print(f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}")
>>> # Plot losses and some samples
>>> fig, axs = plt.subplots(1, 2, figsize=(12, 5))
>>> # Losses
>>> axs[0].plot(losses)
>>> axs[0].set_ylim(0, 0.1)
>>> axs[0].set_title("Loss over time")
>>> # Samples
>>> n_steps = 40
>>> x = torch.rand(64, 1, 28, 28).to(device)
>>> for i in range(n_steps):
... noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps)) # Starting high going low
... with torch.no_grad():
... pred = net(x, 0).sample
... mix_factor = 1 / (n_steps - i)
... x = x * (1 - mix_factor) + pred * mix_factor
>>> axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap="Greys")
>>> axs[1].set_title("Generated Samples")Finished epoch 0. Average loss for this epoch: 0.018925 Finished epoch 1. Average loss for this epoch: 0.012785 Finished epoch 2. Average loss for this epoch: 0.011694
This looks much better than our first set of results! You can explore tweaking the unet configuration or training for longer to get even better performance.
The Corruption Process
The DDPM paper describes a corruption process that adds a small amount of noise for every "timestep". Given $x_{t-1}$ for some timestep, we can get the next (slightly noisier) version $x_t$ with:
$q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t\mathbf{I}) \quad q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1})$
That is, we take $x_{t-1}$, scale it by $\sqrt{1 - \beta_t}$ and add noise scaled by $\sqrt{\beta_t}$. This $\beta$ is defined for every t according to some schedule, and determines how much noise is added per timestep. Now, we don't necessarily want to do this operation 500 times to get $x_{500}$, so we have another formula to get $x_t$ for any t given $x_0$:
$q(\mathbf{x}_t \vert \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t) \mathbf{I})$ where $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$ and $\alpha_i = 1-\beta_i$
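In code, that closed form looks roughly like the sketch below. This is an illustration with an example linear β schedule rather than the actual diffusers implementation, and x here is just the image batch from earlier:
# A rough sketch of the closed-form corruption above (illustration only, not the diffusers code)
betas = torch.linspace(1e-4, 0.02, 1000) # an example linear beta schedule
alpha_bar = torch.cumprod(1 - betas, dim=0) # \bar{\alpha}_t for every timestep
t = 500
eps = torch.randn_like(x) # unit-variance Gaussian noise
x_t = alpha_bar[t].sqrt() * x + (1 - alpha_bar[t]).sqrt() * eps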
Maths notation always looks scary! Luckily the scheduler handles all of that for us (uncomment the next cell to check out the code). We can plot $\sqrt{\bar{\alpha}_t}$ (labelled sqrt_alpha_prod) and $\sqrt{(1 - \bar{\alpha}_t)}$ (labelled sqrt_one_minus_alpha_prod) to view how the input (x) and the noise are scaled and mixed across different timesteps:
# ??noise_scheduler.add_noise
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
>>> plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
>>> plt.legend(fontsize="x-large")
Initially, the noisy x is mostly x (sqrt_alpha_prod ~= 1), but over time the contribution of x drops and the noise component increases. Unlike our linear mix of x and noise according to amount, this one gets noisy relatively quickly. We can visualise this on some data:
>>> # @markdown visualize the DDPM noising process for different timesteps:
>>> # Noise a batch of images to view the effect
>>> fig, axs = plt.subplots(3, 1, figsize=(16, 10))
>>> xb, yb = next(iter(train_dataloader))
>>> xb = xb.to(device)[:8]
>>> xb = xb * 2.0 - 1.0 # Map to (-1, 1)
>>> print("X shape", xb.shape)
>>> # Show clean inputs
>>> axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[0].set_title("Clean X")
>>> # Add noise with scheduler
>>> timesteps = torch.linspace(0, 999, 8).long().to(device)
>>> noise = torch.randn_like(xb) # << NB: randn not rand
>>> noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
>>> print("Noisy X shape", noisy_xb.shape)
>>> # Show noisy version (with and without clipping)
>>> axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap="Greys")
>>> axs[1].set_title("Noisy X (clipped to (-1, 1)")
>>> axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap="Greys")
>>> axs[2].set_title("Noisy X")X shape torch.Size([8, 1, 28, 28]) Noisy X shape torch.Size([8, 1, 28, 28])
Another dynamic to be aware of: the DDPM version adds noise drawn from a Gaussian distribution (mean 0, standard deviation 1, from `torch.randn`) rather than the uniform noise between 0 and 1 that we used in our original `corrupt` function (from `torch.rand`). In general, it also makes sense to normalise the training data. In the other notebook you'll see `Normalize(0.5, 0.5)` in the list of transforms, which maps the image data from (0, 1) to (-1, 1) and is "good enough" for our purposes. We didn't do that in this notebook, but the visualisation cell above includes it for more accurate scaling and visualisation.
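For reference, a loading pipeline with that normalisation might look like the sketch below (a hedged example built around the transform named above; it isn't used anywhere in this notebook):
# A sketch of loading MNIST mapped to (-1, 1), as in the other notebook (illustration only)
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
normalized_dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=transform)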
The Training Objective
In our toy example, we had the model try to predict the denoised image. In DDPM and many other diffusion model implementations, the model predicts the noise used in the corruption process (before scaling, so unit-variance noise). In code, that looks something like:
noise = torch.randn_like(xb) # << NB: randn not rand
noisy_x = noise_scheduler.add_noise(x, noise, timesteps)
model_prediction = model(noisy_x, timesteps).sample
loss = mse_loss(model_prediction, noise) # noise as the target
You might think that predicting the noise (from which we can derive what the denoised image looks like) is equivalent to simply predicting the denoised image directly. So why favour one over the other? Is it just for mathematical convenience?
It turns out there's another subtlety here. We compute the loss across different (randomly chosen) timesteps during training. These different objectives lead to different "implicit weightings" of those losses, and the noise-prediction objective puts more weight on lower noise levels. You can pick more complex objectives to change this "implicit loss weighting", or choose a noise schedule that results in more examples at high noise levels. Or perhaps you have the model predict a "velocity" v, which we define as a combination of both the image and the noise dependent on the noise level (see 'PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS'). Or perhaps you have the model predict the noise but then scale the loss by some factor dependent on the amount of noise, based on a bit of theory (see 'Perception Prioritized Training of Diffusion Models') or based on experiments trying to see which noise levels are most informative to the model (see 'Elucidating the Design Space of Diffusion-Based Generative Models'). TL;DR: the choice of objective has an effect on model performance, and research into the "best" option is ongoing.
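To make the "velocity" idea a little more concrete, here is a rough sketch of that target following the Progressive Distillation paper. This is my own illustration using the $\bar{\alpha}_t$ values from the scheduler above, not code used anywhere in this course:
# A sketch of the v-prediction target (illustration only)
t = 500
alpha_bar_t = noise_scheduler.alphas_cumprod[t] # \bar{\alpha}_t from the DDPMScheduler above
x0, eps = xb, torch.randn_like(xb) # clean images (scaled to (-1, 1)) and unit-variance noise
v_target = alpha_bar_t.sqrt() * eps - (1 - alpha_bar_t).sqrt() * x0
# A v-prediction model would then be trained with mse_loss(model_prediction, v_target) instead of targeting eps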
For now, predicting the noise (you'll see it referred to as epsilon or eps in some places) is the favoured approach, but over time we will likely see other objectives supported in the libraries and used in different situations.
Timestep Conditioning
UNet2DModel takes in both x and a timestep. The latter is turned into an embedding and fed into the model in a number of places.
The theory behind this is that by giving the model information about what the noise level is, it can perform its task better. While it is possible to train a model without this timestep conditioning, it does seem to help performance in some cases, and most implementations include it, at least in the current literature.
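To get a feel for this, here's a small sketch poking at the embedding modules of the model printed above. The shapes are simply what that printout implies; this is for inspection only:
# A sketch inspecting how a timestep becomes an embedding inside the UNet2DModel created above
t = torch.tensor([10, 500, 990]) # three different timesteps
t_emb = model.time_proj(t) # sinusoidal projection, shape (3, 32) for this config
t_emb = model.time_embedding(t_emb) # small MLP, shape (3, 128), fed to each ResnetBlock2D via time_emb_proj
print(t_emb.shape)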
Sampling
Given a model that estimates the noise present in a noisy input (or predicts the denoised version), how do we produce new images?
We could feed in pure noise and just hope the model predicts a good image as the denoised version in one step. However, as we saw in the experiments above, this doesn't usually work well. So instead we take a number of smaller steps guided by the model prediction, removing a little bit of the noise at a time.
Exactly how these steps are taken depends on the sampling method used. We won't go deep into the theory, but some key design questions are listed below (a sketch of a scheduler-driven sampling loop follows the list):
- How large a step should you take? In other words, what "noise schedule" should you follow?
- Do you use only the model's current prediction to inform the update step (like DDPM, DDIM and many others)? Do you evaluate the model several times to estimate higher-order gradients for a larger, more accurate step (higher-order methods and some discrete ODE solvers)? Or do you keep a history of past predictions to better inform the current update step (linear multi-step and ancestral samplers)?
- Do you add additional noise (sometimes called churn) to add more stochasticity to the sampling process, or keep it completely deterministic? Many samplers control this with a parameter (such as 'eta' for DDIM samplers) so that the user can choose.
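As promised, here is a minimal sketch of such a scheduler-driven loop in diffusers. It assumes a model trained to predict the noise (epsilon), which the networks trained earlier in this notebook were not, so treat it as an illustration of the API rather than something to run as-is:
# A sketch of a DDPM-style sampling loop (illustration only; assumes `model` was trained to predict noise)
scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(40) # how many denoising steps to take
sample = torch.randn(8, 1, 28, 28) # start from pure Gaussian noise
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t).sample # the model predicts the noise present at this timestep
    sample = scheduler.step(noise_pred, t, sample).prev_sample # the scheduler removes a little of that noise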
Research on sampling methods for diffusion models is evolving rapidly, and more and more methods for finding good solutions in fewer steps are being proposed. The brave and curious may find it interesting to browse through the code of the different implementations available in the diffusers library here, or check out the docs, which often link to the relevant papers.
Conclusions
Hopefully this has given you a somewhat different perspective on diffusion models.
This notebook was written by Jonathan Whitaker for the Hugging Face course, and overlaps with a version included in his own course, 'The Generative Landscape'. Check that out if you'd like to see this basic example extended with noise and class conditioning. Questions or bugs can be communicated through GitHub issues or on Discord. You're also welcome to reach out via Twitter @johnowhitaker.