Tuesday, December 29, 2020

實作:簡易版 BERT Transformers Multi-Label Classification

簡介
由 Huggingface 所開發的 Transformers Library,雖然可以用 BERT 做 NLP Multi-Class Classification(每一組數據擁有一個 class),但卻未有支援 Multi-Label Classification(每一組數據,可以有多於一個 class)。這次實作,我會以 Kaggle Jigsaw Toxic Comments 作為例子,透過小改 Transformers 的 Multi-Class Classification 範例,令它可以進行 Multi-Label 分類。

---

改進原理
Transformers 中的 BertForSequenceClassification,是一個基於 BERT 及 Attention Model,然後再使用 CrossEntropyLoss 微調訓練的模型。事實上,它也支援自動訓練:只要你使用期間提供了 label,它內置的 CrossEntropyLoss 就能自動計算並返回 loss 值。

這種「黑箱全包」的做法,雖然令入門開發者,可以在不了解 Loss Function 等概念的情況下上手,但同時,它也帶來了一些限制:由於它使用了 CrossEntropyLoss,它就只能支援一個類別的分類。舉例來講,它只能計算出「 預計矩陣 [0 0 0 1] 與期望值 3(也就是 [0 0 0 1] )的距離為 0 」。它並不能輸入多於一個維度的期望值。

所以,這次實作,我只是簡單地改了一下:不用內置於 BertForSequenceClassification 的 CrossEntropyLoss,改為採用 BCEWithLogitsLoss。由於 BCEWithLogitsLoss 本身就是專為 Multi-Label Classification 而設計的 Loss Function,所以我也懶得仔細查究,直接就拿來用。最後,我們就能把 BertForSequenceClassification 改成支援 Multi-Label Classification 的模型。

至於其他,例如 BertTokenizer 等等,則採用 Pretrained 結果,原封不動。

---

重點代碼
這幾行代碼,就是這次改動的大重點:

在訓練過程中:

# use AdamW
optimizer = optim.AdamW(model.parameters(), lr=lr)

# generate prediction
optimizer.zero_grad()
outputs = model(input_ids, attention_mask=attention_mask)
 
# compute gradients and update weights
loss = nn.BCEWithLogitsLoss(outputs.logits, labels)
loss.backward()
optimizer.step()

解說:

  • 使用 AdamW 做 Optimizer(其實其他都可以,只是我發現 AdamW 比較快)。
  • 使用 BertForSequenceClassification model 的過程中,故意不提供 label,
    這樣,Transformers 就不會偷偷地自己計算 Loss 和 Gradients,必須要自己動手。
  • 使用 BCEWithLogitsLoss 來計算 Loss,然後手動做 Backward Propagation 和 Stepping。

 在驗證測試過程中:

# generate prediction
outputs = model(input_ids, attention_mask=attention_mask)
prob = outputs.logits.sigmoid()

# record processed data count
total += (labels.size(0)*labels.size(1))

解說:

  • 用 model 計算結果後,必須經過一層 sigmoid 計算。
    因為 BCEWithLogitsLoss 入面都有 sigmoid,這樣才會一致。
  • 在計算總準繩率時,不能單一以 labels.size(0)(樣本數)做分母,
    相反,要以總標記數 (樣本數 * 每個樣本內的標記數) 做分母。
    因為:假設 Predict = [ 1 0 0 1 ], Expected = [ 1 0 0 0 ],
    它應該算做 3/4 正確,而非一個錯誤。

---

演示代碼
https://github.com/cmcvista/BERTClassifier/blob/main/BertMultiLabelClassifier.ipynb

---

小記
感謝 Joseph 提示方向。

---

參考

[1] - https://huggingface.co/transformers/training.html

[2] - https://towardsdatascience.com/transformers-for-multilabel-classification-71a1a0daf5e1

[3] - https://medium.com/huggingface/multi-label-text-classification-using-bert-the-mighty-transformer-69714fa3fb3d