如何利用對抗性資料動態訓練模型
您將在這裡學到什麼
- 💡動態對抗性資料收集的基本概念及其重要性。
- ⚒ 如何動態收集對抗性資料並在其上訓練您的模型——以 MNIST 手寫數字識別任務為例。
動態對抗性資料收集 (DADC)
靜態基準雖然是評估模型效能的廣泛使用方法,但存在許多問題:它們飽和、存在偏差或漏洞,並且經常導致研究人員追求指標的增加,而不是構建可供人類使用的值得信賴的模型1。
動態對抗性資料收集(DADC)作為一種緩解靜態基準部分問題的方法,前景廣闊。在 DADC 中,人類建立示例來**欺騙**最先進(SOTA)模型。這個過程提供兩個好處:
- 它允許使用者評估其模型的真實魯棒性;
- 它產生的資料可以用於進一步訓練更強大的模型。
這種在對抗性收集的資料上欺騙和訓練模型的過程會重複多輪,從而產生一個更魯棒且與人類對齊的模型1 。
利用對抗性資料動態訓練您的模型
在這裡,我將向您展示如何動態地從使用者那裡收集對抗性資料並根據它們訓練您的模型——以 MNIST 手寫數字識別任務為例。
在 MNIST 手寫數字識別任務中,模型經過訓練,可以根據手寫數字的 28x28
灰度影像輸入(參見下圖中的示例)預測數字。數字範圍從 0 到 9。
這項任務被廣泛認為是計算機視覺的“hello world”,並且很容易訓練模型在標準(和靜態)基準測試集上實現高精度。然而,已經表明這些 SOTA 模型在人類書寫數字(並將其作為輸入提供給模型)時仍然難以預測正確的數字:研究人員認為這很大程度上是因為靜態測試集不能充分代表人類書寫方式的非常多樣性。因此,需要人類參與迴圈,為模型提供**對抗性**樣本,這將有助於它們更好地泛化。
本教程將分為以下幾個部分:
- 配置您的模型
- 與您的模型互動
- 標記您的模型
- 整合所有部分
配置您的模型
首先,您需要定義您的模型架構。我簡單的模型架構由兩個卷積網路組成,它們連線到一個 50 維的全連線層和一個用於 10 個類別的最終層。最後,我們使用 softmax 啟用函式將模型的輸出轉換為類別上的機率分佈。
# Adapted from: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
def __init__(self):
super(MNIST_Model, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x)
現在您已經定義了模型的結構,您需要將其在標準 MNIST 訓練/開發資料集上進行訓練。
與您的模型互動
至此,我們假設您已經訓練好了模型。儘管該模型已經訓練,但我們旨在透過人機協作對抗性資料使其變得健壯。為此,您需要一種使用者與模型互動的方式:具體來說,您希望使用者能夠在畫布上書寫/繪製數字 0-9,並讓模型嘗試對其進行分類。您可以使用 🤗 Spaces 完成所有這些操作,它允許您快速輕鬆地為您的機器學習模型構建演示。在此處瞭解有關 Spaces 以及如何構建它們的更多資訊here。
下面是一個簡單的 Space,用於與我訓練了 20 個 epoch 的 `MNIST_Model` 互動(在測試集上獲得了 89% 的準確率)。您在白色畫布上繪製一個數字,模型會從您的影像中預測該數字。完整的 Space 可以在此處訪問。嘗試欺騙這個模型😁。使用您最有趣的筆跡;在畫布的側面書寫;盡情發揮吧!
標記您的模型
您能騙過上面的模型嗎?😀 如果能,那麼是時候**標記**您的對抗性示例了。標記包括:
- 將對抗性示例儲存到資料集
- 在收集到一定數量的樣本後,對對抗性示例進行模型訓練。
- 重複步驟 1-2 若干次。
我編寫了一個自定義的 `flag` 函式來完成所有這些操作。有關更多詳細資訊,請隨時在此處查閱完整程式碼。
注意:Gradio 有一個內建的標記回撥,可以讓您輕鬆標記模型的對抗性樣本。在此處閱讀更多相關資訊:here。
將所有內容整合
最後一步是將所有三個元件(配置模型、與模型互動和標記模型)整合到一個演示空間中!為此,我建立了 MNIST 對抗性空間,用於 MNIST 手寫識別任務的動態對抗性資料收集。請隨意在下面進行測試。
結論
動態對抗性資料收集 (DADC) 在機器學習社群中作為一種收集多樣化、非飽和、與人類對齊的資料集、改進模型評估和任務效能的方式,正獲得越來越多的關注。透過動態收集帶有模型迴圈的人工生成的對抗性資料,我們可以提高模型的泛化潛力。
這種在對抗性收集的資料上欺騙和訓練模型的過程應該重複多輪1。 Eric Wallace 等人在自然語言推理任務的實驗中表明,儘管短期內標準非對抗性資料收集表現更好,但從長遠來看,動態對抗性資料收集帶來了顯著更高的準確率。
使用 🤗 Spaces,構建一個平臺來動態收集模型的對抗性資料並進行訓練變得相對容易。