timm 文件
快速入門
並獲得增強的文件體驗
開始使用
快速入門
本快速入門旨在幫助開發者深入瞭解程式碼,並提供一個如何將 timm
整合到模型訓練工作流中的示例。
首先,你需要安裝 timm
。有關安裝的更多資訊,請參閱 安裝。
pip install timm
載入預訓練模型
可以使用 create_model() 載入預訓練模型。
在這裡,我們載入預訓練的 mobilenetv3_large_100
模型。
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
列出帶預訓練權重的模型
要列出 timm
中打包的模型,你可以使用 list_models()。如果指定 pretrained=True
,此函式將只返回具有可用預訓練權重的模型名稱。
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
'adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
]
你還可以列出名稱中包含特定模式的模型。
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
'cspresnet50',
'cspresnet50d',
'cspresnet50w',
'cspresnext50',
...
]
微調預訓練模型
你可以透過更改分類器(最後一層)來微調任何預訓練模型。
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
要在你自己的資料集上進行微調,你必須編寫一個 PyTorch 訓練迴圈或調整 timm
的 訓練指令碼 來使用你的資料集。
使用預訓練模型進行特徵提取
在不修改網路的情況下,可以在任何模型上呼叫 model.forward_features(input) 而不是通常的 model(input)。這將繞過網路的頭部自分類器和全域性池化層。
有關使用 timm
進行特徵提取的更深入指南,請參閱 特徵提取。
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
影像增強
要將影像轉換為模型的有效輸入,你可以使用 timm.data.create_transform(),並提供模型期望的 input_size
。
這將返回一個使用合理預設值的通用轉換。
>>> timm.data.create_transform((3, 224, 224))
Compose(
Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
預訓練模型在訓練時對輸入影像應用了特定的轉換。如果你對影像使用了錯誤的轉換,模型將無法理解它看到的內容!
要弄清楚給定的預訓練模型使用了哪些轉換,我們可以先檢視其 pretrained_cfg
>>> model.pretrained_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv_stem',
'classifier': 'classifier',
'architecture': 'mobilenetv3_large_100'}
然後,我們可以使用 timm.data.resolve_data_config() 來僅解析與資料相關的配置。
>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'crop_pct': 0.875}
我們可以將此資料配置傳遞給 timm.data.create_transform() 來初始化模型關聯的轉換。
>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
使用預訓練模型進行推理
在這裡,我們將整合以上各部分,並使用一個預訓練模型進行推理。
首先,我們需要一張影像來進行推理。這裡我們從網上載入一張葉子的圖片。
>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image
這是我們載入的影像

現在,我們將再次建立我們的模型和轉換。這次,我們確保將模型設定為評估模式。
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
我們可以透過將影像傳遞給轉換來為模型準備好這張圖片。
>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])
現在我們可以將該影像傳遞給模型以獲得預測。在這種情況下,我們使用 unsqueeze(0)
,因為模型期望一個批次維度。
>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])
要獲得預測機率,我們對輸出應用 softmax。這將給我們留下一個形狀為 (num_classes,)
的張量。
>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])
現在,我們將使用 torch.topk
找到前 5 個預測的類別索引和值。
>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([281, 282, 285, 673, 670])
如果我們檢查最高索引對應的 imagenet 標籤,我們就可以看到模型預測了什麼……
>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'tabby, tabby_cat', 'value': 0.5101025700569153},
{'label': 'tiger_cat', 'value': 0.22490699589252472},
{'label': 'Egyptian_cat', 'value': 0.1835290789604187},
{'label': 'mouse, computer_mouse', 'value': 0.006752475164830685},
{'label': 'motor_scooter, scooter', 'value': 0.004942195490002632}]