Transformers 文件
ShieldGemma 2
並獲得增強的文件體驗
開始使用
ShieldGemma 2
概述
ShieldGemma 2 模型由 Google 在一份技術報告中提出。ShieldGemma 2 基於 Gemma 3 構建,是一個擁有 40 億(4B)引數的模型,可以根據關鍵類別檢查合成影像和自然影像的安全性,以幫助您構建強大的資料集和模型。透過將此模型新增到 Gemma 模型系列,研究人員和開發人員現在可以輕鬆地將模型中有害內容的風險降至最低,具體定義如下:
- 無露骨色情內容:影像不得包含描繪露骨或圖形化性行為的內容(例如,色情、色情裸露、描繪強姦或性侵犯)。
- 無危險內容:影像不得包含促使或鼓勵可能造成現實世界傷害的活動的內容(例如,製造槍支和爆炸裝置、宣揚恐怖主義、自殺說明)。
- 無暴力/血腥內容:影像不得包含描繪令人震驚、聳人聽聞或無謂暴力的內容(例如,過度血腥、對動物的無謂暴力、極端傷害或死亡瞬間)。
我們建議將 ShieldGemma 2 用作視覺語言模型的輸入過濾器,或影像生成系統的輸出過濾器。為了訓練一個強大的影像安全模型,我們策劃了自然影像和合成影像的訓練資料集,並對 Gemma 3 進行了指令微調,以展示強大的效能。
該模型由Ryan Mullins貢獻。
使用示例
- ShieldGemma 2 提供了一個處理器 (Processor),它接受一個
images列表和一個可選的policies列表作為輸入,並使用提供的聊天模板將這兩個列表的乘積構建成一批提示。 - 您可以透過處理器的
custom_policies引數擴充套件 ShieldGemma 的內建策略。使用與內建策略相同的鍵將用您的自定義定義覆蓋該策略。 - ShieldGemma 2 不支援 Gemma 3 使用的影像裁剪功能。
基於內建策略的分類
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=[image], return_tensors="pt").to(model.device)
output = model(**inputs)
print(output.probabilities)基於自定義策略的分類
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)
custom_policies = {
"key_a": "descrition_a",
"key_b": "descrition_b",
}
inputs = processor(
images=[image],
custom_policies=custom_policies,
policies=["dangerous", "key_a", "key_b"],
return_tensors="pt",
).to(model.device)
output = model(**inputs)
print(output.probabilities)ShieldGemma2Processor
class transformers.ShieldGemma2Processor
< source >( image_processor tokenizer chat_template = None image_seq_length = 256 policy_definitions = None **kwargs )
ShieldGemma2Config
class transformers.ShieldGemma2Config
< source >( text_config = None vision_config = None mm_tokens_per_image: int = 256 boi_token_index: int = 255999 eoi_token_index: int = 256000 image_token_index: int = 262144 initializer_range: float = 0.02 **kwargs )
引數
- text_config (
Union[ShieldGemma2TextConfig, dict], 可選) — 文字骨幹的配置物件。 - vision_config (
Union[AutoConfig, dict], 可選) — 自定義視覺配置或字典。 - mm_tokens_per_image (
int, 可選, 預設為 256) — 每個影像嵌入的標記數量。 - boi_token_index (
int, 可選, 預設為 255999) — 用於包裹影像提示的影像起始標記索引。 - eoi_token_index (
int, 可選, 預設為 256000) — 用於包裹影像提示的影像結束標記索引。 - image_token_index (
int, 可選, 預設為 262144) — 用於編碼影像提示的影像標記索引。 - initializer_range (
float, 可選, 預設為 0.02) — 用於初始化所有權重矩陣的截斷正態初始化器的標準差。
這是儲存 ShieldGemma2ForImageClassification 配置的配置類。它用於根據指定引數例項化 ShieldGemma2ForImageClassification,定義模型架構。使用預設值例項化配置將產生與 shieldgemma-2-4b-it 相似的配置。
配置物件繼承自 PretrainedConfig,可用於控制模型輸出。有關更多資訊,請參閱 PretrainedConfig 的文件。
示例
>>> from transformers import ShieldGemma2ForConditionalGeneration, ShieldGemma2Config, SiglipVisionConfig, ShieldGemma2TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a ShieldGemma2 Text config
>>> text_config = ShieldGemma2TextConfig()
>>> # Initializing a ShieldGemma2 gemma-3-4b style configuration
>>> configuration = ShieldGemma2Config(vision_config, text_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = ShieldGemma2TextConfig(configuration)
>>> # Accessing the model configuration
>>> configuration = model.configShieldGemma2ForImageClassification
class transformers.ShieldGemma2ForImageClassification
< source >( config: ShieldGemma2Config )
引數
- config (ShieldGemma2Config) — 包含模型所有引數的模型配置類。使用配置檔案初始化不會載入與模型相關的權重,只加載配置。請檢視 from_pretrained() 方法以載入模型權重。
Shieldgemma2 模型,其頂部帶有一個影像分類頭,例如用於 ImageNet。
此模型繼承自 PreTrainedModel。請檢視超類文件,瞭解庫為其所有模型實現的通用方法(例如下載或儲存、調整輸入嵌入大小、修剪頭部等)。
此模型也是 PyTorch torch.nn.Module 子類。將其作為常規 PyTorch 模組使用,並參考 PyTorch 文件以瞭解與一般用法和行為相關的所有事項。
forward
< source >( input_ids: typing.Optional[torch.LongTensor] = None pixel_values: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.Tensor] = None position_ids: typing.Optional[torch.LongTensor] = None past_key_values: typing.Union[list[torch.FloatTensor], transformers.cache_utils.Cache, NoneType] = None token_type_ids: typing.Optional[torch.LongTensor] = None cache_position: typing.Optional[torch.LongTensor] = None inputs_embeds: typing.Optional[torch.FloatTensor] = None labels: typing.Optional[torch.LongTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None logits_to_keep: typing.Union[int, torch.Tensor] = 0 **lm_kwargs )
引數
- input_ids (形狀為
(batch_size, sequence_length)的torch.LongTensor,可選) — 詞彙表中輸入序列標記的索引。預設情況下將忽略填充。可以使用 AutoTokenizer 獲取索引。有關詳細資訊,請參閱 PreTrainedTokenizer.encode() 和 PreTrainedTokenizer.call()。
- pixel_values (形狀為
(batch_size, num_channels, image_size, image_size)的torch.FloatTensor,可選) — 對應於輸入影像的張量。畫素值可以使用{image_processor_class}獲取。有關詳細資訊,請參閱{image_processor_class}.__call__({processor_class}使用{image_processor_class}處理影像)。 - attention_mask (形狀為
(batch_size, sequence_length)的torch.Tensor,可選) — 遮罩,用於避免對填充標記索引執行注意力。遮罩值選擇在[0, 1]之間:- 1 表示**未被遮罩**的標記,
- 0 表示**被遮罩**的標記。
- position_ids (形狀為
(batch_size, sequence_length)的torch.LongTensor,可選) — 每個輸入序列標記在位置嵌入中的位置索引。選擇範圍為[0, config.n_positions - 1]。 - past_key_values (
Union[list[torch.FloatTensor], ~cache_utils.Cache, NoneType]) — 預先計算的隱藏狀態(自注意力塊和交叉注意力塊中的鍵和值),可用於加速順序解碼。這通常包括模型在解碼前期返回的past_key_values,當use_cache=True或config.use_cache=True時。允許兩種格式:
- Cache 例項,請參閱我們的 kv cache 指南;
- 長度為
config.n_layers的tuple(torch.FloatTensor)元組,每個元組包含 2 個形狀為(batch_size, num_heads, sequence_length, embed_size_per_head)的張量。這也稱為傳統快取格式。
模型將輸出與輸入相同的快取格式。如果未傳入
past_key_values,則將返回傳統快取格式。如果使用
past_key_values,使用者可以選擇僅輸入最後一個input_ids(那些沒有將其過去鍵值狀態提供給此模型的)形狀為(batch_size, 1),而不是所有input_ids形狀為(batch_size, sequence_length)。 - token_type_ids (形狀為
(batch_size, sequence_length)的torch.LongTensor,可選) — 分段標記索引,指示輸入的第一個和第二個部分。索引選擇在[0, 1]之間:- 0 對應於**句子 A** 標記,
- 1 對應於**句子 B** 標記。
- cache_position (形狀為
(sequence_length)的torch.LongTensor,可選) — 描繪輸入序列標記在序列中位置的索引。與position_ids不同,此張量不受填充影響。它用於在正確位置更新快取並推斷完整的序列長度。 - inputs_embeds (形狀為
(batch_size, sequence_length, hidden_size)的torch.FloatTensor,可選) — 可選地,您可以直接傳入嵌入表示,而不是傳入input_ids。如果您想對如何將input_ids索引轉換為關聯向量有比模型內部嵌入查詢矩陣更多的控制權,這將很有用。 - labels (形狀為
(batch_size, sequence_length)的torch.LongTensor,可選) — 用於計算遮罩語言建模損失的標籤。索引應為[0, ..., config.vocab_size]或 -100 (請參閱input_ids文件字串)。索引設定為-100的標記將被忽略(遮罩),損失僅針對標籤在[0, ..., config.vocab_size]中的標記計算。 - use_cache (
bool, 可選) — 如果設定為True,將返回past_key_values鍵值狀態,可用於加速解碼(參見past_key_values)。 - output_attentions (
bool, 可選) — 是否返回所有注意力層的注意力張量。有關更多詳細資訊,請參閱返回張量下的attentions。 - output_hidden_states (
bool, 可選) — 是否返回所有層的隱藏狀態。有關更多詳細資訊,請參閱返回張量下的hidden_states。 - return_dict (
bool, 可選) — 是否返回 ModelOutput 而不是純元組。 - logits_to_keep (
Union[int, torch.Tensor], 預設為0) — 如果是int,則計算最後logits_to_keep個標記的邏輯值。如果是0,則計算所有input_ids的邏輯值(特殊情況)。生成時只需要最後一個標記的邏輯值,僅計算該標記可以節省記憶體,這對於長序列或大詞彙量非常重要。如果是torch.Tensor,則必須是對應於序列長度維度中要保留的索引的一維張量。這在使用打包張量格式(批次和序列長度的單維度)時很有用。
ShieldGemma2ForImageClassification 的 forward 方法,覆蓋了 __call__ 特殊方法。
儘管前向傳播的配方需要在此函式中定義,但應在此之後呼叫 Module 例項,因為前者負責執行預處理和後處理步驟,而後者則默默地忽略它們。
示例
>>> from transformers import AutoImageProcessor, ShieldGemma2ForImageClassification
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("google/gemma-3-4b")
>>> model = ShieldGemma2ForImageClassification.from_pretrained("google/gemma-3-4b")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
...