Transformers 文件

DiT

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

PyTorch Flax

DiT

DiT 是一種影像 Transformer 模型,在大規模未標註文件影像上進行預訓練。它學習從損壞的輸入影像中預測缺失的視覺標記。預訓練的 DiT 模型可用作其他模型的骨幹網路,用於文件影像分類和表格檢測等視覺文件任務。

你可以在 Microsoft 組織下找到所有原始的 DiT 檢查點。

請參閱 BEiT 文件,瞭解如何將 DiT 應用於不同視覺任務的更多示例。

以下示例展示瞭如何使用 PipelineAutoModel 類對影像進行分類。

<hfoptions id="usage"> <hfoption id="Pipeline">
import torch
from transformers import pipeline

pipeline = pipeline(
    task="image-classification",
    model="microsoft/dit-base-finetuned-rvlcdip",
    torch_dtype=torch.float16,
    device=0
)
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg")
</hfoption> <hfoption id="AutoModel">
import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    use_fast=True,
)
model = AutoModelForImageClassification.from_pretrained(
    "microsoft/dit-base-finetuned-rvlcdip",
    device_map="auto",
)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dit-example.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt").to("cuda")

with torch.no_grad():
  logits = model(**inputs).logits
predicted_class_id = logits.argmax(dim=-1).item()

class_labels = model.config.id2label
predicted_class_label = class_labels[predicted_class_id]
print(f"The predicted class label is: {predicted_class_label}")
</hfoption>

注意事項

  • 預訓練的 DiT 權重可以載入到帶有建模頭的 [BEiT] 模型中以預測視覺標記。
    from transformers import BeitForMaskedImageModeling
    
    model = BeitForMaskedImageModeling.from_pretraining("microsoft/dit-base")

資源

  • 有關文件影像分類推理的示例,請參閱此筆記本
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.