Diffusers 文件

UNet2DConditionModel

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

UNet2DConditionModel

UNet 模型最初由 Ronneberger 等人引入,用於生物醫學影像分割,但它也常用於 🤗 Diffusers,因為它輸出的影像大小與輸入相同。它是擴散系統最重要的元件之一,因為它促進了實際的擴散過程。在 🤗 Diffusers 中,UNet 模型有多種變體,具體取決於其維度數量以及它是否是條件模型。這是一個 2D UNet 條件模型。

論文摘要如下:

人們普遍認為,深度網路的成功訓練需要數千個帶註釋的訓練樣本。在本文中,我們提出了一種網路和訓練策略,它強烈依賴資料增強,以更有效地利用可用的帶註釋樣本。該架構包括一個收縮路徑用於捕獲上下文,以及一個對稱的擴充套件路徑,可實現精確的定位。我們展示了這樣的網路可以從很少的影像端到端訓練,並且在 ISBI 挑戰賽中,在電子顯微鏡堆疊中分割神經元結構的任務上,其效能優於先前的最佳方法(滑動視窗卷積網路)。使用在透射光顯微鏡影像(相差和 DIC)上訓練的相同網路,我們在 2015 年 ISBI 細胞追蹤挑戰賽的這些類別中以大幅優勢獲勝。此外,該網路速度很快。在最新的 GPU 上分割 512x512 影像所需時間不到一秒。完整實現(基於 Caffe)和訓練好的網路可在 http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net 獲取。

UNet2DConditionModel

diffusers.UNet2DConditionModel

< >

( sample_size: typing.Union[int, typing.Tuple[int, int], NoneType] = None in_channels: int = 4 out_channels: int = 4 center_input_sample: bool = False flip_sin_to_cos: bool = True freq_shift: int = 0 down_block_types: typing.Tuple[str] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' up_block_types: typing.Tuple[str] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int] = (320, 640, 1280, 1280) layers_per_block: typing.Union[int, typing.Tuple[int]] = 2 downsample_padding: int = 1 mid_block_scale_factor: float = 1 dropout: float = 0.0 act_fn: str = 'silu' norm_num_groups: typing.Optional[int] = 32 norm_eps: float = 1e-05 cross_attention_dim: typing.Union[int, typing.Tuple[int]] = 1280 transformer_layers_per_block: typing.Union[int, typing.Tuple[int], typing.Tuple[typing.Tuple]] = 1 reverse_transformer_layers_per_block: typing.Optional[typing.Tuple[typing.Tuple[int]]] = None encoder_hid_dim: typing.Optional[int] = None encoder_hid_dim_type: typing.Optional[str] = None attention_head_dim: typing.Union[int, typing.Tuple[int]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int], NoneType] = None dual_cross_attention: bool = False use_linear_projection: bool = False class_embed_type: typing.Optional[str] = None addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None num_class_embeds: typing.Optional[int] = None upcast_attention: bool = False resnet_time_scale_shift: str = 'default' resnet_skip_time_act: bool = False resnet_out_scale_factor: float = 1.0 time_embedding_type: str = 'positional' time_embedding_dim: typing.Optional[int] = None time_embedding_act_fn: typing.Optional[str] = None timestep_post_act: typing.Optional[str] = None time_cond_proj_dim: typing.Optional[int] = None conv_in_kernel: int = 3 conv_out_kernel: int = 3 projection_class_embeddings_input_dim: typing.Optional[int] = None attention_type: str = 'default' class_embeddings_concat: bool = False mid_block_only_cross_attention: typing.Optional[bool] = None cross_attention_norm: typing.Optional[str] = None addition_embed_type_num_heads: int = 64 )

引數

  • sample_size (intTuple[int, int], 可選, 預設為 None) — 輸入/輸出樣本的高度和寬度。
  • in_channels (int, 可選, 預設為 4) — 輸入樣本中的通道數。
  • out_channels (int, 可選, 預設為 4) — 輸出中的通道數。
  • center_input_sample (bool, 可選, 預設為 False) — 是否對輸入樣本進行居中。
  • flip_sin_to_cos (bool, 可選, 預設為 True) — 是否在時間嵌入中將 sin 翻轉為 cos。
  • freq_shift (int, 可選, 預設為 0) — 應用於時間嵌入的頻率偏移。
  • down_block_types (Tuple[str], 可選, 預設為 ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")) — 要使用的下采樣塊的元組。
  • mid_block_type (str, 可選, 預設為 "UNetMidBlock2DCrossAttn") — UNet 中間塊的塊型別,可以是 UNetMidBlock2DCrossAttn, UNetMidBlock2D, 或 UNetMidBlock2DSimpleCrossAttn 之一。如果為 None,則跳過中間塊層。
  • up_block_types (Tuple[str], 可選, 預設為 ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")) — 要使用的上取樣塊的元組。
  • only_cross_attention(boolTuple[bool], 可選, 預設為 False) — 是否在基本轉換器塊中包含自注意力,請參閱 BasicTransformerBlock
  • block_out_channels (Tuple[int], 可選, 預設為 (320, 640, 1280, 1280)) — 每個塊的輸出通道元組。
  • layers_per_block (int, 可選, 預設為 2) — 每個塊的層數。
  • downsample_padding (int, 可選, 預設為 1) — 用於下采樣卷積的填充。
  • mid_block_scale_factor (float, 可選, 預設為 1.0) — 用於中間塊的比例因子。
  • dropout (float, 可選, 預設為 0.0) — 要使用的 dropout 機率。
  • act_fn (str, 可選, 預設為 "silu") — 要使用的啟用函式。
  • norm_num_groups (int, 可選, 預設為 32) — 用於歸一化的組數。如果為 None,則在後處理中跳過歸一化和啟用層。
  • norm_eps (float, 可選, 預設為 1e-5) — 用於歸一化的 epsilon。
  • cross_attention_dim (intTuple[int], 可選, 預設為 1280) — 交叉注意力特徵的維度。
  • transformer_layers_per_block (int, Tuple[int], 或 Tuple[Tuple], 可選, 預設為 1) — BasicTransformerBlock 型別的轉換器塊的數量。僅與 CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn 相關。
  • reverse_transformer_layers_per_block — (Tuple[Tuple], 可選, 預設為 None): U-Net 上取樣塊中 BasicTransformerBlock 型別的轉換器塊的數量。僅當 transformer_layers_per_blockTuple[Tuple] 型別時,且對於 CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn 相關。
  • encoder_hid_dim (int, 可選, 預設為 None) — 如果定義了 encoder_hid_dim_type,則 encoder_hidden_states 將從 encoder_hid_dim 維度投射到 cross_attention_dim
  • encoder_hid_dim_type (str, 可選, 預設為 None) — 如果給定,encoder_hidden_states 和可能其他嵌入將根據 encoder_hid_dim_type 降維投射到 cross_attention 維度的文字嵌入。
  • attention_head_dim (int, 可選, 預設為 8) — 注意力頭的維度。
  • num_attention_heads (int, 可選) — 注意力頭的數量。如果未定義,則預設為 attention_head_dim
  • resnet_time_scale_shift (str, 可選, 預設為 "default") — ResNet 塊的時間尺度偏移配置(參見 ResnetBlock2D)。可選擇 defaultscale_shift
  • class_embed_type (str, 可選, 預設為 None) — 類嵌入的型別,最終會與時間嵌入相加。可選擇 None, "timestep", "identity", "projection", 或 "simple_projection"
  • addition_embed_type (str, 可選, 預設為 None) — 配置一個可選的嵌入,該嵌入將與時間嵌入相加。可選擇 None 或 “text”。“text” 將使用 TextTimeEmbedding 層。
  • addition_time_embed_dim — (int, 可選, 預設為 None): 時間步嵌入的維度。
  • num_class_embeds (int, 可選, 預設為 None) — 可學習嵌入矩陣的輸入維度,當使用 class_embed_typeNone 進行類別條件化時,該矩陣將投射到 time_embed_dim
  • time_embedding_type (str, 可選, 預設為 positional) — 用於時間步長的位置嵌入型別。可選擇 positionalfourier
  • time_embedding_dim (int, 可選, 預設為 None) — 投影時間嵌入的可選維度覆蓋。
  • time_embedding_act_fn (str, 可選, 預設為 None) — 在時間嵌入傳遞給 UNet 的其餘部分之前,僅使用一次的可選啟用函式。可選擇 silumishgeluswish
  • timestep_post_act (str, 可選, 預設為 None) — 在時間步長嵌入中使用的第二個啟用函式。可選擇 silumishgelu
  • time_cond_proj_dim (int, 可選, 預設為 None) — 時間步長嵌入中 cond_proj 層的維度。
  • conv_in_kernel (int, 可選, 預設為 3) — conv_in 層的核大小。
  • conv_out_kernel (int, 可選, 預設為 3) — conv_out 層的核大小。
  • projection_class_embeddings_input_dim (int, 可選) — 當 class_embed_type="projection" 時,class_labels 輸入的維度。當 class_embed_type="projection" 時必需。
  • class_embeddings_concat (bool, 可選, 預設為 False) — 是否將時間嵌入與類別嵌入拼接。
  • mid_block_only_cross_attention (bool, 可選, 預設為 None) — 在使用 UNetMidBlock2DSimpleCrossAttn 時,是否使用帶有中間塊的交叉注意力。如果 only_cross_attention 給定為單個布林值且 mid_block_only_cross_attentionNone,則 only_cross_attention 的值將用作 mid_block_only_cross_attention 的值。否則預設為 False

一個條件 2D UNet 模型,接收噪聲樣本、條件狀態和時間步長,並返回一個樣本形狀的輸出。

此模型繼承自 ModelMixin。有關所有模型實現的通用方法(如下載或儲存),請參閱超類文件。

停用 freeu

< >

( )

停用 FreeU 機制。

啟用 freeu

< >

( s1: float s2: float b1: float b2: float )

引數

  • s1 (float) — 階段 1 的縮放因子,用於衰減跳過特徵的貢獻。這樣做是為了減輕增強去噪過程中的“過度平滑效應”。
  • s2 (float) — 階段 2 的縮放因子,用於衰減跳過特徵的貢獻。這樣做是為了減輕增強去噪過程中的“過度平滑效應”。
  • b1 (float) — 階段 1 的縮放因子,用於放大主幹特徵的貢獻。
  • b2 (float) — 階段 2 的縮放因子,用於放大主幹特徵的貢獻。

啟用來自 https://huggingface.co/papers/2309.11497 的 FreeU 機制。

縮放因子後面的字尾表示它們正在應用的階段塊。

請參閱官方倉庫,瞭解適用於 Stable Diffusion v1、v2 和 Stable Diffusion XL 等不同管道的已知良好值組合。

前向傳播

< >

( sample: Tensor timestep: typing.Union[torch.Tensor, float, int] encoder_hidden_states: Tensor class_labels: typing.Optional[torch.Tensor] = None timestep_cond: typing.Optional[torch.Tensor] = None attention_mask: typing.Optional[torch.Tensor] = None cross_attention_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None added_cond_kwargs: typing.Optional[typing.Dict[str, torch.Tensor]] = None down_block_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None mid_block_additional_residual: typing.Optional[torch.Tensor] = None down_intrablock_additional_residuals: typing.Optional[typing.Tuple[torch.Tensor]] = None encoder_attention_mask: typing.Optional[torch.Tensor] = None return_dict: bool = True ) UNet2DConditionOutputtuple

引數

  • sample (torch.Tensor) — 形狀為 (batch, channel, height, width) 的帶噪輸入張量。
  • timestep (torch.Tensorfloatint) — 用於去噪輸入的時間步長數量。
  • encoder_hidden_states (torch.Tensor) — 形狀為 (batch, sequence_length, feature_dim) 的編碼器隱藏狀態。
  • class_labels (torch.Tensor, 可選, 預設為 None) — 用於條件作用的可選類別標籤。它們的嵌入將與時間步長嵌入求和。
  • timestep_cond — (torch.Tensor, 可選, 預設為 None):時間步長的條件嵌入。如果提供,嵌入將與透過 self.time_embedding 層傳遞的樣本求和,以獲得時間步長嵌入。
  • attention_mask (torch.Tensor, 可選, 預設為 None) — 形狀為 (batch, key_tokens) 的注意力掩碼應用於 encoder_hidden_states。如果為 1,則保留掩碼,否則為 0 則丟棄。掩碼將被轉換為偏差,這將為與“丟棄”標記對應的注意力分數新增較大的負值。
  • cross_attention_kwargs (dict, 可選) — 如果指定,將作為 kwargs 字典傳遞給 AttentionProcessor,其定義在 diffusers.models.attention_processor 中的 self.processor 下。
  • added_cond_kwargs — (dict, 可選):一個 kwargs 字典,如果指定,其中包含的額外嵌入將新增到傳遞給 UNet 塊的嵌入中。
  • down_block_additional_residuals — (tuple of torch.Tensor, 可選):如果指定,將新增到 UNet 下行塊殘差的張量元組。
  • mid_block_additional_residual — (torch.Tensor, 可選):如果指定,將新增到中間 UNet 塊殘差的張量。
  • down_intrablock_additional_residuals (tuple of torch.Tensor, 可選) — 要新增到 UNet 下行塊內的額外殘差,例如來自 T2I-Adapter 側模型的殘差。
  • encoder_attention_mask (torch.Tensor) — 形狀為 (batch, sequence_length) 的交叉注意力掩碼應用於 encoder_hidden_states。如果為 True,則保留掩碼,否則為 False 則丟棄。掩碼將被轉換為偏差,這將為與“丟棄”標記對應的注意力分數新增較大的負值。
  • return_dict (bool, 可選, 預設為 True) — 是否返回 UNet2DConditionOutput 而不是普通元組。

返回

UNet2DConditionOutputtuple

如果 return_dict 為 True,則返回 UNet2DConditionOutput,否則返回 tuple,其中第一個元素是樣本張量。

UNet2DConditionModel 前向傳播方法。

融合 qkv 投影

< >

( )

啟用融合 QKV 投影。對於自注意力模組,所有投影矩陣(即查詢、鍵、值)都將融合。對於交叉注意力模組,鍵和值投影矩陣將融合。

此 API 是 🧪 實驗性的。

設定注意力切片

< >

( slice_size: typing.Union[str, int, typing.List[int]] = 'auto' )

引數

  • slice_size (strintlist(int), 可選, 預設為 "auto") — 當為 "auto" 時,輸入到注意力頭的資料減半,因此注意力分兩步計算。如果為 "max",則透過每次只執行一個切片來節省最大記憶體。如果提供數字,則使用 attention_head_dim // slice_size 個切片。在這種情況下,attention_head_dim 必須是 slice_size 的倍數。

啟用分片注意力計算。

啟用此選項後,注意力模組會將輸入張量分片以分步計算注意力。這對於節省記憶體非常有用,但會稍微降低速度。

設定注意力處理器

< >

( processor: typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor, typing.Dict[str, typing.Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor2_0, diffusers.models.attention_processor.JointAttnProcessor2_0, diffusers.models.attention_processor.PAGJointAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGJointAttnProcessor2_0, diffusers.models.attention_processor.FusedJointAttnProcessor2_0, diffusers.models.attention_processor.AllegroAttnProcessor2_0, diffusers.models.attention_processor.AuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FusedAuraFlowAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0, diffusers.models.attention_processor.FluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0, diffusers.models.attention_processor.FusedFluxAttnProcessor2_0_NPU, diffusers.models.attention_processor.CogVideoXAttnProcessor2_0, diffusers.models.attention_processor.FusedCogVideoXAttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnAddedKVProcessor, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.XLAFlashAttnProcessor2_0, diffusers.models.attention_processor.AttnProcessorNPU, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.MochiVaeAttnProcessor2_0, diffusers.models.attention_processor.MochiAttnProcessor2_0, diffusers.models.attention_processor.StableAudioAttnProcessor2_0, diffusers.models.attention_processor.HunyuanAttnProcessor2_0, diffusers.models.attention_processor.FusedHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGHunyuanAttnProcessor2_0, diffusers.models.attention_processor.LuminaAttnProcessor2_0, diffusers.models.attention_processor.FusedAttnProcessor2_0, diffusers.models.attention_processor.CustomDiffusionXFormersAttnProcessor, diffusers.models.attention_processor.CustomDiffusionAttnProcessor2_0, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.models.attention_processor.SanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleLinearAttention, diffusers.models.attention_processor.SanaMultiscaleAttnProcessor2_0, diffusers.models.attention_processor.SanaMultiscaleAttentionProjection, diffusers.models.attention_processor.IPAdapterAttnProcessor, diffusers.models.attention_processor.IPAdapterAttnProcessor2_0, diffusers.models.attention_processor.IPAdapterXFormersAttnProcessor, diffusers.models.attention_processor.SD3IPAdapterJointAttnProcessor2_0, diffusers.models.attention_processor.PAGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0, diffusers.models.attention_processor.LoRAAttnProcessor, diffusers.models.attention_processor.LoRAAttnProcessor2_0, diffusers.models.attention_processor.LoRAXFormersAttnProcessor, diffusers.models.attention_processor.LoRAAttnAddedKVProcessor]]] )

引數

  • processor (AttentionProcessor 字典或僅 AttentionProcessor) — 將設定為 所有 Attention 層的處理器的例項化處理器類或處理器類字典。

    如果 processor 是一個字典,則鍵需要定義到相應交叉注意力處理器的路徑。強烈建議在設定可訓練注意力處理器時這樣做。

設定用於計算注意力的注意力處理器。

設定預設注意力處理器

< >

( )

停用自定義注意力處理器並設定預設注意力實現。

取消融合 QKV 投影

< >

( )

如果啟用了,則停用融合的 QKV 投影。

此 API 是 🧪 實驗性的。

UNet2DConditionOutput

class diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput

< >

( sample: Tensor = None )

引數

  • sample (torch.Tensor, 形狀為 (batch_size, num_channels, height, width)) — 在 encoder_hidden_states 輸入條件下輸出的隱藏狀態。模型的最後一層輸出。

UNet2DConditionModel 的輸出。

FlaxUNet2DConditionModel

class diffusers.FlaxUNet2DConditionModel

< >

( sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 down_block_types: typing.Tuple[str, ...] = ('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D') up_block_types: typing.Tuple[str, ...] = ('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D') mid_block_type: typing.Optional[str] = 'UNetMidBlock2DCrossAttn' only_cross_attention: typing.Union[bool, typing.Tuple[bool]] = False block_out_channels: typing.Tuple[int, ...] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: typing.Union[int, typing.Tuple[int, ...]] = 8 num_attention_heads: typing.Union[int, typing.Tuple[int, ...], NoneType] = None cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False dtype: dtype = <class 'jax.numpy.float32'> flip_sin_to_cos: bool = True freq_shift: int = 0 use_memory_efficient_attention: bool = False split_head_dim: bool = False transformer_layers_per_block: typing.Union[int, typing.Tuple[int, ...]] = 1 addition_embed_type: typing.Optional[str] = None addition_time_embed_dim: typing.Optional[int] = None addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: typing.Optional[int] = None parent: typing.Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object at 0x7fc460aac610> name: typing.Optional[str] = None )

引數

  • sample_size (int, 可選) — 輸入樣本的大小。
  • in_channels (int, 可選, 預設為 4) — 輸入樣本中的通道數。
  • out_channels (int, 可選, 預設為 4) — 輸出中的通道數。
  • down_block_types (Tuple[str], 可選, 預設為 ("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")) — 要使用的下采樣塊的元組。
  • up_block_types (Tuple[str], 可選, 預設為 ("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")) — 要使用的上取樣塊的元組。
  • mid_block_type (str, 可選, 預設為 "UNetMidBlock2DCrossAttn") — UNet 中間塊的塊型別,可以是 UNetMidBlock2DCrossAttn 之一。如果為 None,則跳過中間塊層。
  • block_out_channels (Tuple[int], 可選, 預設為 (320, 640, 1280, 1280)) — 每個塊的輸出通道元組。
  • layers_per_block (int, 可選, 預設為 2) — 每個塊的層數。
  • attention_head_dim (intTuple[int], 可選, 預設為 8) — 注意力頭的維度。
  • num_attention_heads (intTuple[int], 可選) — 注意力頭的數量。
  • cross_attention_dim (int, 可選, 預設為 768) — 交叉注意力特徵的維度。
  • dropout (float, 可選, 預設為 0) — 下采樣、上取樣和瓶頸塊的 dropout 機率。
  • flip_sin_to_cos (bool, 可選, 預設為 True) — 是否在時間嵌入中將 sin 翻轉為 cos。
  • freq_shift (int, 可選, 預設為 0) — 應用於時間嵌入的頻率偏移。
  • use_memory_efficient_attention (bool, 可選, 預設為 False) — 啟用 此處 描述的記憶體高效注意力。
  • split_head_dim (bool, 可選, 預設為 False) — 是否將頭部維度拆分為自注意力計算的新軸。在大多數情況下,啟用此標誌應能加速 Stable Diffusion 2.x 和 Stable Diffusion XL 的計算。

一個條件 2D UNet 模型,接收噪聲樣本、條件狀態和時間步長,並返回一個樣本形狀的輸出。

此模型繼承自 FlaxModelMixin。請檢視超類文件以瞭解所有模型實現的通用方法(例如下載或儲存)。

此模型也是 Flax Linen flax.linen.Module 的子類。將其作為常規 Flax Linen 模組使用,並參閱 Flax 文件中與其一般用法和行為相關的所有內容。

支援以下固有的 JAX 功能

FlaxUNet2DConditionOutput

class diffusers.models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput

< >

( sample: Array )

引數

  • sample (形狀為 (batch_size, num_channels, height, width)jnp.ndarray) — 以 encoder_hidden_states 輸入為條件的隱藏狀態輸出。模型最後一層的輸出。

FlaxUNet2DConditionModel 的輸出。

替換

< >

( **updates )

返回一個新物件,用新值替換指定的欄位。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.