Fine-tuning a Semantic Segmentation Model on a Custom Dataset and Usage via the Inference API

Authored by: Sergio Paniego

In this notebook, we will walk through fine-tuning a semantic segmentation model on a custom dataset. The model we will use is the pre-trained SegFormer, a powerful and flexible Transformer-based architecture for segmentation tasks.

Segformer architecture

For our dataset, we will use segments/sidewalk-semantic, which contains images labeled with sidewalk annotations, making it ideal for applications in urban environments.

Example use case: this model could be deployed in a delivery robot that autonomously navigates sidewalks to bring a pizza right to your doorstep 🍕

After fine-tuning the model, we will demonstrate how to deploy it using the Serverless Inference API, making it accessible through a simple API endpoint.

1. Install dependencies

First, we will install the essential libraries needed to fine-tune our semantic segmentation model.

!pip install -q datasets transformers evaluate wandb
# Tested with datasets==3.0.0, transformers==4.44.2, evaluate==0.4.3, wandb==0.18.1

2. Load the dataset 📁

We will use the sidewalk-semantic dataset, which contains images of sidewalks collected in Belgium during the summer of 2021.

The dataset includes:

  • 1,000 images with their corresponding semantic segmentation masks 🖼
  • 34 distinct categories 📦

Since this dataset is gated, you will need to log in and accept the license to access it. We also need to be authenticated in order to upload the fine-tuned model to the Hub after training.

from huggingface_hub import notebook_login

notebook_login()
sidewalk_dataset_identifier = "segments/sidewalk-semantic"
from datasets import load_dataset

dataset = load_dataset(sidewalk_dataset_identifier)

Take a look inside the dataset to get familiar with its structure!

dataset

Since the dataset only contains a training split, we will manually split it into training and test sets. We will use 80% of the data for training and the remaining 20% for evaluation and testing. ➗

dataset = dataset.shuffle(seed=42)
dataset = dataset["train"].train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]

Let's check the types of objects present in an example. We can see that pixel_values holds the RGB image, while label holds the ground-truth mask. The mask is a single-channel image in which each pixel represents the category of the corresponding pixel in the RGB image.

image = train_ds[0]
image
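
As a quick optional sanity check (a minimal sketch, assuming the example loaded above is still stored in the variable image), we can list which class IDs actually appear in this mask with NumPy. Each unique value corresponds to an entry in the id2label mapping that we load in the next section.

import numpy as np

# Optional sanity check: the mask is a single-channel image whose pixel values are class IDs.
mask_np = np.array(image["label"])
print("Mask shape:", mask_np.shape)  # (height, width)
print("Class IDs present in this mask:", np.unique(mask_np))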

3. Visualize some examples! 👀

Now that we have loaded the dataset, let's visualize a few examples and their masks to better understand its structure.

The dataset contains a JSON file with the id2label mapping. We will open this file to read the category label associated with each ID.

>>> import json
>>> from huggingface_hub import hf_hub_download

>>> filename = "id2label.json"
>>> id2label = json.load(
...     open(hf_hub_download(repo_id=sidewalk_dataset_identifier, filename=filename, repo_type="dataset"), "r")
... )
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}

>>> num_labels = len(id2label)
>>> print("Id2label:", id2label)
Id2label: {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}

Let's assign a color to each category 🎨. This will help us visualize the segmentation results more effectively and make it easier to interpret the different categories in the images.

sidewalk_palette = [
    [0, 0, 0],  # unlabeled
    [216, 82, 24],  # flat-road
    [255, 255, 0],  # flat-sidewalk
    [125, 46, 141],  # flat-crosswalk
    [118, 171, 47],  # flat-cyclinglane
    [161, 19, 46],  # flat-parkingdriveway
    [255, 0, 0],  # flat-railtrack
    [0, 128, 128],  # flat-curb
    [190, 190, 0],  # human-person
    [0, 255, 0],  # human-rider
    [0, 0, 255],  # vehicle-car
    [170, 0, 255],  # vehicle-truck
    [84, 84, 0],  # vehicle-bus
    [84, 170, 0],  # vehicle-tramtrain
    [84, 255, 0],  # vehicle-motorcycle
    [170, 84, 0],  # vehicle-bicycle
    [170, 170, 0],  # vehicle-caravan
    [170, 255, 0],  # vehicle-cartrailer
    [255, 84, 0],  # construction-building
    [255, 170, 0],  # construction-door
    [255, 255, 0],  # construction-wall
    [33, 138, 200],  # construction-fenceguardrail
    [0, 170, 127],  # construction-bridge
    [0, 255, 127],  # construction-tunnel
    [84, 0, 127],  # construction-stairs
    [84, 84, 127],  # object-pole
    [84, 170, 127],  # object-trafficsign
    [84, 255, 127],  # object-trafficlight
    [170, 0, 127],  # nature-vegetation
    [170, 84, 127],  # nature-terrain
    [170, 170, 127],  # sky
    [170, 255, 127],  # void-ground
    [255, 0, 127],  # void-dynamic
    [255, 84, 127],  # void-static
    [255, 170, 127],  # void-unclear
]

We can visualize some examples from the dataset, including the RGB image, the corresponding mask, and the mask overlaid on the image. This will help us better understand the dataset and how the masks correspond to the images. 📸

>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from PIL import Image
>>> import matplotlib.patches as patches

>>> # Create and show the legend separately
>>> fig, ax = plt.subplots(figsize=(18, 2))

>>> legend_patches = [
...     patches.Patch(color=np.array(color) / 255, label=label)
...     for label, color in zip(id2label.values(), sidewalk_palette)
... ]

>>> ax.legend(handles=legend_patches, loc="center", bbox_to_anchor=(0.5, 0.5), ncol=5, fontsize=8)
>>> ax.axis("off")

>>> plt.show()

>>> for i in range(5):
...     image = train_ds[i]

...     fig, ax = plt.subplots(1, 3, figsize=(18, 6))

...     # Show the original image
...     ax[0].imshow(image["pixel_values"])
...     ax[0].set_title("Original Image")
...     ax[0].axis("off")

...     mask_np = np.array(image["label"])

...     # Create a new empty RGB image
...     colored_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)

...     # Assign colors to each value in the mask
...     for label_id, color in enumerate(sidewalk_palette):
...         colored_mask[mask_np == label_id] = color

...     colored_mask_img = Image.fromarray(colored_mask, "RGB")

...     # Show the segmentation mask
...     ax[1].imshow(colored_mask_img)
...     ax[1].set_title("Segmentation Mask")
...     ax[1].axis("off")

...     # Convert the original image to RGBA to support transparency
...     image_rgba = image["pixel_values"].convert("RGBA")
...     colored_mask_rgba = colored_mask_img.convert("RGBA")

...     # Adjust transparency of the mask
...     alpha = 128  # Transparency level (0 fully transparent, 255 fully opaque)
...     image_2_with_alpha = Image.new("RGBA", colored_mask_rgba.size)
...     for x in range(colored_mask_rgba.width):
...         for y in range(colored_mask_rgba.height):
...             r, g, b, a = colored_mask_rgba.getpixel((x, y))
...             image_2_with_alpha.putpixel((x, y), (r, g, b, alpha))

...     superposed = Image.alpha_composite(image_rgba, image_2_with_alpha)

...     # Show the mask overlay
...     ax[2].imshow(superposed)
...     ax[2].set_title("Mask Overlay")
...     ax[2].axis("off")

...     plt.show()

4. Visualize the number of occurrences per category 📊

To gain deeper insight into the dataset, let's plot the number of occurrences of each category. This will help us understand the class distribution and identify potential biases or imbalances in the dataset.

import matplotlib.pyplot as plt
import numpy as np

class_counts = np.zeros(len(id2label))

for example in train_ds:
    mask_np = np.array(example["label"])
    unique, counts = np.unique(mask_np, return_counts=True)
    for u, c in zip(unique, counts):
        class_counts[u] += c
>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from matplotlib import patches

>>> labels = list(id2label.values())

>>> # Normalize colors to be in the range [0, 1]
>>> normalized_palette = [tuple(c / 255 for c in color) for color in sidewalk_palette]

>>> # Visualization
>>> fig, ax = plt.subplots(figsize=(12, 8))

>>> bars = ax.bar(range(len(labels)), class_counts, color=[normalized_palette[i] for i in range(len(labels))])

>>> ax.set_xticks(range(len(labels)))
>>> ax.set_xticklabels(labels, rotation=90, ha="right")

>>> ax.set_xlabel("Categories", fontsize=14)
>>> ax.set_ylabel("Number of Occurrences", fontsize=14)
>>> ax.set_title("Number of Occurrences by Category", fontsize=16)

>>> ax.grid(axis="y", linestyle="--", alpha=0.7)

>>> # Adjust the y-axis limit
>>> y_max = max(class_counts)
>>> ax.set_ylim(0, y_max * 1.25)

>>> for bar in bars:
...     height = int(bar.get_height())
...     offset = 10  # Adjust the text location
...     ax.text(
...         bar.get_x() + bar.get_width() / 2.0,
...         height + offset,
...         f"{height}",
...         ha="center",
...         va="bottom",
...         rotation=90,
...         fontsize=10,
...         color="black",
...     )

>>> fig.legend(
...     handles=legend_patches, loc="center left", bbox_to_anchor=(1, 0.5), ncol=1, fontsize=8
... )  # Adjust ncol as needed

>>> plt.tight_layout()
>>> plt.show()

5. Initialize the image processor and add data augmentation with Albumentations 📸

We will first initialize the image processor and then apply data augmentation using Albumentations 🪄. This will help enrich our dataset and improve the performance of the semantic segmentation model.

import albumentations as A
from transformers import SegformerImageProcessor

image_processor = SegformerImageProcessor()

albumentations_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.7),
        A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.5),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=25, val_shift_limit=20, p=0.5),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
        A.GaussNoise(var_limit=(10, 50), p=0.4),
    ]
)


def train_transforms(example_batch):
    augmented = [
        albumentations_transform(image=np.array(image), mask=np.array(label))
        for image, label in zip(example_batch["pixel_values"], example_batch["label"])
    ]
    augmented_images = [item["image"] for item in augmented]
    augmented_labels = [item["mask"] for item in augmented]
    inputs = image_processor(augmented_images, augmented_labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = image_processor(images, labels)
    return inputs


# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)
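
As an optional sanity check (a minimal sketch; the exact keys and shapes depend on the datasets and transformers versions and on the processor's configured size), indexing the dataset now runs the transform on the fly, so we can peek at what the processor produces before training:

import numpy as np

# Optional: indexing an example triggers the transform lazily.
# The processor should return normalized "pixel_values" and an integer "labels" map,
# roughly (3, 512, 512) and (512, 512) with the default SegformerImageProcessor size.
sample = train_ds[0]
print(sample.keys())
print(np.array(sample["pixel_values"]).shape)
print(np.array(sample["labels"]).shape)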

6. Initialize the model from a checkpoint

We will use a pre-trained SegFormer model with the nvidia/mit-b0 checkpoint. This architecture is described in detail in the paper SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers and was pre-trained on ImageNet-1k.

from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model_name, id2label=id2label, label2id=label2id)

7. Set up the training arguments and connect to Weights & Biases 📉

Next, we will configure the training arguments and connect to Weights & Biases (W&B). W&B will help us track experiments, visualize metrics, and manage the model training workflow, providing valuable insights throughout the process.

from transformers import TrainingArguments

output_dir = "test-segformer-b0-segments-sidewalk-finetuned"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=6e-5,
    num_train_epochs=20,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_total_limit=2,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=True,
    report_to="wandb",
)
import wandb

wandb.init(
    project="test-segformer-b0-segments-sidewalk-finetuned",  # change this
    name="test-segformer-b0-segments-sidewalk-finetuned",  # change this
    config=training_args,
)

8. Set up a custom compute_metrics method for enhanced logging with evaluate

We will use mean Intersection over Union (mean IoU) as the primary metric to evaluate the model's performance. This will allow us to track per-category performance in detail.

Additionally, we will adjust the logging level of the evaluate module to minimize warnings in the output. If a category is not detected in an image, you may see a warning like the following:

RuntimeWarning: invalid value encountered in divide iou = total_area_intersect / total_area_union

If you prefer to see these warnings, you can skip this cell and move on to the next one.

import evaluate

evaluate.logging.set_verbosity_error()
import torch
from torch import nn
import multiprocessing

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        # currently using _compute instead of compute: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        pred_labels = logits_tensor.detach().cpu().numpy()
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            metrics = metric._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return metrics

9. Train the model on our dataset 🏋

It's time to train the model on our custom dataset. We will use the training arguments we prepared, together with the Weights & Biases integration, to monitor the training process and make adjustments as needed. Let's start training and watch the model's performance improve!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
trainer.train()

10. Evaluate the model's performance on new images 📸

After training, we will evaluate the model on new images. We will use a test image and leverage pipeline to assess how the model performs on unseen data.

import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw

url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

image = Image.open(requests.get(url, stream=True).raw)

image_segmentator = pipeline(
    "image-segmentation",
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
)

results = image_segmentator(image)
>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

The model has generated some masks, so we can visualize them to assess and understand its performance. This will help us see how well the model segments the images and identify areas for improvement.

>>> image_array = np.array(image)

>>> segmentation_map = np.zeros_like(image_array)

>>> for result in results:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

11. Evaluate performance on the test set 📊

>>> metrics = trainer.evaluate(test_ds)
>>> print(metrics)
{'eval_loss': 0.6063494086265564, 'eval_mean_iou': 0.26682655949637757, 'eval_mean_accuracy': 0.3233445959272099, 'eval_overall_accuracy': 0.834762670692357, 'eval_accuracy_unlabeled': nan, 'eval_accuracy_flat-road': 0.8794976463015708, 'eval_accuracy_flat-sidewalk': 0.9287807675111692, 'eval_accuracy_flat-crosswalk': 0.5247038032656313, 'eval_accuracy_flat-cyclinglane': 0.795399495199148, 'eval_accuracy_flat-parkingdriveway': 0.4010852199852775, 'eval_accuracy_flat-railtrack': nan, 'eval_accuracy_flat-curb': 0.4902816930389514, 'eval_accuracy_human-person': 0.5913439011934908, 'eval_accuracy_human-rider': 0.0, 'eval_accuracy_vehicle-car': 0.9253204043875328, 'eval_accuracy_vehicle-truck': 0.0, 'eval_accuracy_vehicle-bus': 0.0, 'eval_accuracy_vehicle-tramtrain': 0.0, 'eval_accuracy_vehicle-motorcycle': 0.0, 'eval_accuracy_vehicle-bicycle': 0.0013499147866290941, 'eval_accuracy_vehicle-caravan': 0.0, 'eval_accuracy_vehicle-cartrailer': 0.0, 'eval_accuracy_construction-building': 0.8815560533904696, 'eval_accuracy_construction-door': 0.0, 'eval_accuracy_construction-wall': 0.4455930603622635, 'eval_accuracy_construction-fenceguardrail': 0.3431640802292688, 'eval_accuracy_construction-bridge': 0.0, 'eval_accuracy_construction-tunnel': nan, 'eval_accuracy_construction-stairs': 0.0, 'eval_accuracy_object-pole': 0.24341265579591848, 'eval_accuracy_object-trafficsign': 0.0, 'eval_accuracy_object-trafficlight': 0.0, 'eval_accuracy_nature-vegetation': 0.9478392425169023, 'eval_accuracy_nature-terrain': 0.8560970005175594, 'eval_accuracy_sky': 0.9530036096232858, 'eval_accuracy_void-ground': 0.0, 'eval_accuracy_void-dynamic': 0.0, 'eval_accuracy_void-static': 0.13859852156564748, 'eval_accuracy_void-unclear': 0.0, 'eval_iou_unlabeled': nan, 'eval_iou_flat-road': 0.7270368663334998, 'eval_iou_flat-sidewalk': 0.8484429155310914, 'eval_iou_flat-crosswalk': 0.3716762279636531, 'eval_iou_flat-cyclinglane': 0.6983685965068486, 'eval_iou_flat-parkingdriveway': 0.3073600964845036, 'eval_iou_flat-railtrack': nan, 'eval_iou_flat-curb': 0.3781660047058077, 'eval_iou_human-person': 0.38559031115261033, 'eval_iou_human-rider': 0.0, 'eval_iou_vehicle-car': 0.7473290757373612, 'eval_iou_vehicle-truck': 0.0, 'eval_iou_vehicle-bus': 0.0, 'eval_iou_vehicle-tramtrain': 0.0, 'eval_iou_vehicle-motorcycle': 0.0, 'eval_iou_vehicle-bicycle': 0.0013499147866290941, 'eval_iou_vehicle-caravan': 0.0, 'eval_iou_vehicle-cartrailer': 0.0, 'eval_iou_construction-building': 0.6637240016649857, 'eval_iou_construction-door': 0.0, 'eval_iou_construction-wall': 0.3336225132267832, 'eval_iou_construction-fenceguardrail': 0.3131070176565442, 'eval_iou_construction-bridge': 0.0, 'eval_iou_construction-tunnel': nan, 'eval_iou_construction-stairs': 0.0, 'eval_iou_object-pole': 0.17741310577170807, 'eval_iou_object-trafficsign': 0.0, 'eval_iou_object-trafficlight': 0.0, 'eval_iou_nature-vegetation': 0.837720086429597, 'eval_iou_nature-terrain': 0.7272281817316115, 'eval_iou_sky': 0.9005169994943569, 'eval_iou_void-ground': 0.0, 'eval_iou_void-dynamic': 0.0, 'eval_iou_void-static': 0.11979798870649179, 'eval_iou_void-unclear': 0.0, 'eval_runtime': 30.5276, 'eval_samples_per_second': 6.551, 'eval_steps_per_second': 0.819, 'epoch': 20.0}

12. Access the model via the Inference API and visualize the results 🔌

Hugging Face 🤗 provides a Serverless Inference API that lets you test models directly and for free through an API endpoint. For a detailed guide on using this API, check out this recipe.

We will use this API to explore its capabilities and see how it can be used to test our model.

Important note

Before using the Serverless Inference API, you need to set the model task by creating a model card. When creating the model card for your fine-tuned model, make sure to specify the task appropriately.

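If you prefer to set the task programmatically instead of editing the model card by hand, a minimal sketch using metadata_update from huggingface_hub could look like this (the repo id below is the one used throughout this notebook; replace it with your own fine-tuned model):

from huggingface_hub import metadata_update

# Set the pipeline task in the model card metadata so the Serverless
# Inference API knows how to serve the model.
metadata_update(
    "sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
    {"pipeline_tag": "image-segmentation"},
    overwrite=True,
)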

Once the model task is set, we can download an image and test the model using the InferenceClient. This client will let us send the image to the model through the API and retrieve the results for evaluation.

>>> url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

We will use the image_segmentation method of the InferenceClient. This method takes the model and an image as input and returns the predicted masks. This lets us test how the model performs on new images.

from huggingface_hub import InferenceClient

client = InferenceClient()

response = client.image_segmentation(
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
    image="https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)

print(response)

With the predicted masks in hand, we are ready to display the results.

>>> image_array = np.array(image)
>>> segmentation_map = np.zeros_like(image_array)

>>> for result in response:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

You can also interact with the Inference API using JavaScript. Here is an example of how you could call the API from JavaScript:

import { HfInference } from "@huggingface/inference";

const inference = new HfInference(HF_TOKEN);
await inference.imageSegmentation({
    data: await (await fetch("https://picsum.photos/300/300")).blob(),
    model: "sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
});

Additional tip

You can also deploy the fine-tuned model using a Hugging Face Space. For example, I created a custom Space to showcase this: Semantic segmentation fine-tuned on Segments/Sidewalk using SegFormer

from IPython.display import IFrame

IFrame(src="https://sergiopaniego-segformer-b0-segments-sidewalk-finetuned.hf.space", width=1000, height=800)
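
For reference, a minimal Gradio app for such a Space might look like the sketch below. This is only an assumption of how the demo could be structured, not the exact code of the linked Space; replace the model id with your own fine-tuned checkpoint.

import gradio as gr
import numpy as np
from transformers import pipeline

# Load the fine-tuned checkpoint and return the predicted segments for an uploaded image.
segmenter = pipeline(
    "image-segmentation",
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
)


def segment(image):
    results = segmenter(image)
    # gr.AnnotatedImage expects (image, [(mask, label), ...]) with float masks in [0, 1]
    annotations = [(np.array(r["mask"], dtype=float) / 255.0, r["label"]) for r in results]
    return image, annotations


demo = gr.Interface(fn=segment, inputs=gr.Image(type="pil"), outputs=gr.AnnotatedImage())
demo.launch()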

Conclusion

In this guide, we successfully fine-tuned a semantic segmentation model on a custom dataset and tested it using the Serverless Inference API. This shows how easily you can integrate models into a wide range of applications and leverage Hugging Face tools for deployment.

I hope this guide gives you the tools and knowledge you need to fine-tune and deploy your own models with confidence! 🚀
