@@ -0,0 +1,1382 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the Stable Diffusion checkpoints."""
+
+import re
+from io import BytesIO
+from typing import Optional
+
+import requests
+import torch
+from transformers import (
+    AutoFeatureExtractor,
+    BertTokenizerFast,
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionConfig,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.models import (
+    AutoencoderKL,
+    ControlNetModel,  # required by convert_controlnet_checkpoint below
+    PriorTransformer,
+    UNet2DConditionModel,
+)
+from diffusers.schedulers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+    UnCLIPScheduler,
+)
+# from diffusers.utils import is_omegaconf_available, is_safetensors_available, logging
+from diffusers.utils.import_utils import BACKENDS_MAPPING
+# LDMBert, PaintByExample and StableUnCLIP classes are instantiated by the converters
+# below, so these imports have to stay active (module paths assume a 2023-era diffusers layout):
+from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
+from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+
+# Only needed if the commented-out pipeline assembly code at the bottom of this file is restored:
+# from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+# from .safety_checker import StableDiffusionSafetyChecker
+
+
+# logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
|
|
|
|
|
+ """
|
|
|
|
|
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
|
|
|
|
|
+ """
|
|
|
|
|
+ if n_shave_prefix_segments >= 0:
|
|
|
|
|
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
|
|
|
|
|
+ else:
|
|
|
|
|
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
|
|
|
|
|
+
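+# For example (illustrative): shave_segments("model.diffusion_model.input_blocks.1.0.weight", 2)
+# returns "input_blocks.1.0.weight"; a negative value drops that many trailing segments instead.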
|
|
|
|
|
+
|
|
|
|
|
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
|
|
|
|
+ """
|
|
|
|
|
+ Updates paths inside resnets to the new naming scheme (local renaming)
|
|
|
|
|
+ """
|
|
|
|
|
+ mapping = []
|
|
|
|
|
+ for old_item in old_list:
|
|
|
|
|
+ new_item = old_item.replace("in_layers.0", "norm1")
|
|
|
|
|
+ new_item = new_item.replace("in_layers.2", "conv1")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("out_layers.0", "norm2")
|
|
|
|
|
+ new_item = new_item.replace("out_layers.3", "conv2")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
|
|
|
|
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
|
|
+
|
|
|
|
|
+ mapping.append({"old": old_item, "new": new_item})
|
|
|
|
|
+
|
|
|
|
|
+ return mapping
|
|
|
|
|
+
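+# For example (illustrative): an LDM key "input_blocks.1.0.in_layers.2.weight" is mapped to
+# {"old": "input_blocks.1.0.in_layers.2.weight", "new": "input_blocks.1.0.conv1.weight"};
+# the global block renaming is applied later in `assign_to_checkpoint`.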
|
|
|
|
|
+
|
|
|
|
|
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
|
|
|
|
+ """
|
|
|
|
|
+ Updates paths inside resnets to the new naming scheme (local renaming)
|
|
|
|
|
+ """
|
|
|
|
|
+ mapping = []
|
|
|
|
|
+ for old_item in old_list:
|
|
|
|
|
+ new_item = old_item
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
|
|
|
|
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
|
|
+
|
|
|
|
|
+ mapping.append({"old": old_item, "new": new_item})
|
|
|
|
|
+
|
|
|
|
|
+ return mapping
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
|
|
|
|
+ """
|
|
|
|
|
+ Updates paths inside attentions to the new naming scheme (local renaming)
|
|
|
|
|
+ """
|
|
|
|
|
+ mapping = []
|
|
|
|
|
+ for old_item in old_list:
|
|
|
|
|
+ new_item = old_item
|
|
|
|
|
+
|
|
|
|
|
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
|
|
|
|
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
|
|
|
|
+
|
|
|
|
|
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
|
|
|
|
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
|
|
|
|
+
|
|
|
|
|
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
|
|
+
|
|
|
|
|
+ mapping.append({"old": old_item, "new": new_item})
|
|
|
|
|
+
|
|
|
|
|
+ return mapping
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
|
|
|
|
+ """
|
|
|
|
|
+ Updates paths inside attentions to the new naming scheme (local renaming)
|
|
|
|
|
+ """
|
|
|
|
|
+ mapping = []
|
|
|
|
|
+ for old_item in old_list:
|
|
|
|
|
+ new_item = old_item
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
|
|
|
|
|
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("q.weight", "query.weight")
|
|
|
|
|
+ new_item = new_item.replace("q.bias", "query.bias")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("k.weight", "key.weight")
|
|
|
|
|
+ new_item = new_item.replace("k.bias", "key.bias")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("v.weight", "value.weight")
|
|
|
|
|
+ new_item = new_item.replace("v.bias", "value.bias")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
|
|
|
|
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
|
|
|
|
+
|
|
|
|
|
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
|
|
+
|
|
|
|
|
+ mapping.append({"old": old_item, "new": new_item})
|
|
|
|
|
+
|
|
|
|
|
+ return mapping
|
|
|
|
|
+
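+# For example (illustrative): "mid.attn_1.q.weight" becomes "mid.attn_1.query.weight" here, and the
+# "mid.attn_1" prefix is rewritten to "mid_block.attentions.0" later through the
+# `additional_replacements` argument of `assign_to_checkpoint`.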
|
|
|
|
|
+
|
|
|
|
|
+def assign_to_checkpoint(
|
|
|
|
|
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
|
|
|
|
+ attention layers, and takes into account additional replacements that may arise.
|
|
|
|
|
+
|
|
|
|
|
+ Assigns the weights to the new checkpoint.
|
|
|
|
|
+ """
|
|
|
|
|
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
|
|
|
|
+
|
|
|
|
|
+ # Splits the attention layers into three variables.
|
|
|
|
|
+ if attention_paths_to_split is not None:
|
|
|
|
|
+ for path, path_map in attention_paths_to_split.items():
|
|
|
|
|
+ old_tensor = old_checkpoint[path]
|
|
|
|
|
+ channels = old_tensor.shape[0] // 3
|
|
|
|
|
+
|
|
|
|
|
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
|
|
|
|
+
|
|
|
|
|
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
|
|
|
|
+
|
|
|
|
|
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
|
|
|
|
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
|
|
|
|
+
|
|
|
|
|
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
|
|
|
|
|
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
|
|
|
|
|
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
|
|
|
|
|
+
|
|
|
|
|
+ for path in paths:
|
|
|
|
|
+ new_path = path["new"]
|
|
|
|
|
+
|
|
|
|
|
+ # These have already been assigned
|
|
|
|
|
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
|
|
|
|
+ continue
|
|
|
|
|
+
|
|
|
|
|
+ # Global renaming happens here
|
|
|
|
|
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
|
|
|
|
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
|
|
|
|
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
|
|
|
|
+
|
|
|
|
|
+ if additional_replacements is not None:
|
|
|
|
|
+ for replacement in additional_replacements:
|
|
|
|
|
+ new_path = new_path.replace(replacement["old"], replacement["new"])
|
|
|
|
|
+
|
|
|
|
|
+ # proj_attn.weight has to be converted from conv 1D to linear
|
|
|
|
|
+ if "proj_attn.weight" in new_path:
|
|
|
|
|
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
|
|
|
|
+ else:
|
|
|
|
|
+ checkpoint[new_path] = old_checkpoint[path["old"]]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def conv_attn_to_linear(checkpoint):
|
|
|
|
|
+ keys = list(checkpoint.keys())
|
|
|
|
|
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if ".".join(key.split(".")[-2:]) in attn_keys:
|
|
|
|
|
+ if checkpoint[key].ndim > 2:
|
|
|
|
|
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
|
|
|
|
+ elif "proj_attn.weight" in key:
|
|
|
|
|
+ if checkpoint[key].ndim > 2:
|
|
|
|
|
+ checkpoint[key] = checkpoint[key][:, :, 0]
|
|
|
|
|
+
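+# Note (illustrative): in the original VAE the attention q/k/v and output projections are 1x1
+# convolutions, so their weights carry trailing spatial dimensions of size 1; the slicing above
+# squeezes those out so the tensors can be loaded into diffusers' linear attention layers.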
|
|
|
|
|
+
|
|
|
|
|
+def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
|
|
|
|
+ """
|
|
|
|
|
+ Creates a config for the diffusers based on the config of the LDM model.
|
|
|
|
|
+ """
|
|
|
|
|
+ if controlnet:
|
|
|
|
|
+ unet_params = original_config.model.params.control_stage_config.params
|
|
|
|
|
+ else:
|
|
|
|
|
+ unet_params = original_config.model.params.unet_config.params
|
|
|
|
|
+
|
|
|
|
|
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
|
|
|
|
+
|
|
|
|
|
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
|
|
|
|
+
|
|
|
|
|
+ down_block_types = []
|
|
|
|
|
+ resolution = 1
|
|
|
|
|
+ for i in range(len(block_out_channels)):
|
|
|
|
|
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
|
|
|
|
+ down_block_types.append(block_type)
|
|
|
|
|
+ if i != len(block_out_channels) - 1:
|
|
|
|
|
+ resolution *= 2
|
|
|
|
|
+
|
|
|
|
|
+ up_block_types = []
|
|
|
|
|
+ for i in range(len(block_out_channels)):
|
|
|
|
|
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
|
|
|
|
+ up_block_types.append(block_type)
|
|
|
|
|
+ resolution //= 2
|
|
|
|
|
+
|
|
|
|
|
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
|
|
|
|
+
|
|
|
|
|
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
|
|
|
|
+ use_linear_projection = (
|
|
|
|
|
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
|
|
|
|
+ )
|
|
|
|
|
+ if use_linear_projection:
|
|
|
|
|
+ # stable diffusion 2-base-512 and 2-768
|
|
|
|
|
+ if head_dim is None:
|
|
|
|
|
+ head_dim = [5, 10, 20, 20]
|
|
|
|
|
+
|
|
|
|
|
+ class_embed_type = None
|
|
|
|
|
+ projection_class_embeddings_input_dim = None
|
|
|
|
|
+
|
|
|
|
|
+ if "num_classes" in unet_params:
|
|
|
|
|
+ if unet_params.num_classes == "sequential":
|
|
|
|
|
+ class_embed_type = "projection"
|
|
|
|
|
+ assert "adm_in_channels" in unet_params
|
|
|
|
|
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
|
|
|
|
+
|
|
|
|
|
+ config = {
|
|
|
|
|
+ "sample_size": image_size // vae_scale_factor,
|
|
|
|
|
+ "in_channels": unet_params.in_channels,
|
|
|
|
|
+ "down_block_types": tuple(down_block_types),
|
|
|
|
|
+ "block_out_channels": tuple(block_out_channels),
|
|
|
|
|
+ "layers_per_block": unet_params.num_res_blocks,
|
|
|
|
|
+ "cross_attention_dim": unet_params.context_dim,
|
|
|
|
|
+ "attention_head_dim": head_dim,
|
|
|
|
|
+ "use_linear_projection": use_linear_projection,
|
|
|
|
|
+ "class_embed_type": class_embed_type,
|
|
|
|
|
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if not controlnet:
|
|
|
|
|
+ config["out_channels"] = unet_params.out_channels
|
|
|
|
|
+ config["up_block_types"] = tuple(up_block_types)
|
|
|
|
|
+
|
|
|
|
|
+ return config
|
|
|
|
|
+
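+# Illustrative example (values taken from the standard v1-inference.yaml, i.e. an assumption about
+# the checkpoint being converted): with model_channels=320, channel_mult=[1, 2, 4, 4] and
+# attention_resolutions=[4, 2, 1], create_unet_diffusers_config yields
+# block_out_channels=(320, 640, 1280, 1280) and down_block_types=("CrossAttnDownBlock2D",
+# "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D").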
|
|
|
|
|
+
|
|
|
|
|
+def create_vae_diffusers_config(original_config, image_size: int):
|
|
|
|
|
+ """
|
|
|
|
|
+ Creates a config for the diffusers based on the config of the LDM model.
|
|
|
|
|
+ """
|
|
|
|
|
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
|
|
|
|
+ _ = original_config.model.params.first_stage_config.params.embed_dim
|
|
|
|
|
+
|
|
|
|
|
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
|
|
|
|
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
|
|
|
|
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
|
|
|
|
+
|
|
|
|
|
+ config = {
|
|
|
|
|
+ "sample_size": image_size,
|
|
|
|
|
+ "in_channels": vae_params.in_channels,
|
|
|
|
|
+ "out_channels": vae_params.out_ch,
|
|
|
|
|
+ "down_block_types": tuple(down_block_types),
|
|
|
|
|
+ "up_block_types": tuple(up_block_types),
|
|
|
|
|
+ "block_out_channels": tuple(block_out_channels),
|
|
|
|
|
+ "latent_channels": vae_params.z_channels,
|
|
|
|
|
+ "layers_per_block": vae_params.num_res_blocks,
|
|
|
|
|
+ }
|
|
|
|
|
+ return config
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def create_diffusers_schedular(original_config):
|
|
|
|
|
+ schedular = DDIMScheduler(
|
|
|
|
|
+ num_train_timesteps=original_config.model.params.timesteps,
|
|
|
|
|
+ beta_start=original_config.model.params.linear_start,
|
|
|
|
|
+ beta_end=original_config.model.params.linear_end,
|
|
|
|
|
+ beta_schedule="scaled_linear",
|
|
|
|
|
+ )
|
|
|
|
|
+ return schedular
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def create_ldm_bert_config(original_config):
+    bert_params = original_config.model.params.cond_stage_config.params
+    config = LDMBertConfig(
+        d_model=bert_params.n_embed,
+        encoder_layers=bert_params.n_layer,
+        encoder_ffn_dim=bert_params.n_embed * 4,
+    )
+    return config
+
|
|
|
|
|
+
|
|
|
|
|
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
|
|
|
|
+ """
|
|
|
|
|
+ Takes a state dict and a config, and returns a converted checkpoint.
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ # extract state_dict for UNet
|
|
|
|
|
+ unet_state_dict = {}
|
|
|
|
|
+ keys = list(checkpoint.keys())
|
|
|
|
|
+
|
|
|
|
|
+ if controlnet:
|
|
|
|
|
+ unet_key = "control_model."
|
|
|
|
|
+ else:
|
|
|
|
|
+ unet_key = "model.diffusion_model."
|
|
|
|
|
+
|
|
|
|
|
+    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
|
|
|
|
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
|
|
|
|
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
|
|
|
|
+ print(
|
|
|
|
|
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
|
|
|
|
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
|
|
|
|
+ )
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if key.startswith("model.diffusion_model"):
|
|
|
|
|
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
|
|
|
|
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
|
|
|
|
+ else:
|
|
|
|
|
+ if sum(k.startswith("model_ema") for k in keys) > 100:
|
|
|
|
|
+ print(
|
|
|
|
|
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
|
|
|
|
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if key.startswith(unet_key):
|
|
|
|
|
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint = {}
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
|
|
|
|
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
|
|
|
|
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
|
|
|
|
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ if config["class_embed_type"] is None:
|
|
|
|
|
+ # No parameters to port
|
|
|
|
|
+ ...
|
|
|
|
|
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
|
|
|
|
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
|
|
|
|
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
|
|
|
|
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
|
|
|
|
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
|
|
|
|
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ if not controlnet:
|
|
|
|
|
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
|
|
|
|
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
|
|
|
|
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
|
|
|
|
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieves the keys for the input blocks only
|
|
|
|
|
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
|
|
|
|
+ input_blocks = {
|
|
|
|
|
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
|
|
|
|
+ for layer_id in range(num_input_blocks)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieves the keys for the middle blocks only
|
|
|
|
|
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
|
|
|
|
+ middle_blocks = {
|
|
|
|
|
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
|
|
|
|
+ for layer_id in range(num_middle_blocks)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieves the keys for the output blocks only
|
|
|
|
|
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
|
|
|
|
+ output_blocks = {
|
|
|
|
|
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
|
|
|
|
+ for layer_id in range(num_output_blocks)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(1, num_input_blocks):
|
|
|
|
|
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
|
|
|
|
|
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
|
|
|
|
+
|
|
|
|
|
+ resnets = [
|
|
|
|
|
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
|
|
|
|
+ ]
|
|
|
|
|
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
|
|
|
|
+
|
|
|
|
|
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
|
|
|
|
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_blocks.{i}.0.op.weight"
|
|
|
|
|
+ )
|
|
|
|
|
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_blocks.{i}.0.op.bias"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ paths = renew_resnet_paths(resnets)
|
|
|
|
|
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
|
|
|
|
+ assign_to_checkpoint(
|
|
|
|
|
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if len(attentions):
|
|
|
|
|
+ paths = renew_attention_paths(attentions)
|
|
|
|
|
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
|
|
|
|
+ assign_to_checkpoint(
|
|
|
|
|
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ resnet_0 = middle_blocks[0]
|
|
|
|
|
+ attentions = middle_blocks[1]
|
|
|
|
|
+ resnet_1 = middle_blocks[2]
|
|
|
|
|
+
|
|
|
|
|
+ resnet_0_paths = renew_resnet_paths(resnet_0)
|
|
|
|
|
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
|
|
|
|
+
|
|
|
|
|
+ resnet_1_paths = renew_resnet_paths(resnet_1)
|
|
|
|
|
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
|
|
|
|
+
|
|
|
|
|
+ attentions_paths = renew_attention_paths(attentions)
|
|
|
|
|
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
|
|
|
|
+ assign_to_checkpoint(
|
|
|
|
|
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(num_output_blocks):
|
|
|
|
|
+ block_id = i // (config["layers_per_block"] + 1)
|
|
|
|
|
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
|
|
|
|
|
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
|
|
|
|
+ output_block_list = {}
|
|
|
|
|
+
|
|
|
|
|
+ for layer in output_block_layers:
|
|
|
|
|
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
|
|
|
|
+ if layer_id in output_block_list:
|
|
|
|
|
+ output_block_list[layer_id].append(layer_name)
|
|
|
|
|
+ else:
|
|
|
|
|
+ output_block_list[layer_id] = [layer_name]
|
|
|
|
|
+
|
|
|
|
|
+ if len(output_block_list) > 1:
|
|
|
|
|
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
|
|
|
|
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
|
|
|
|
+
|
|
|
|
|
+ resnet_0_paths = renew_resnet_paths(resnets)
|
|
|
|
|
+ paths = renew_resnet_paths(resnets)
|
|
|
|
|
+
|
|
|
|
|
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
|
|
|
|
+ assign_to_checkpoint(
|
|
|
|
|
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
|
|
|
|
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
|
|
|
|
|
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
|
|
|
|
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
|
|
|
|
+ f"output_blocks.{i}.{index}.conv.weight"
|
|
|
|
|
+ ]
|
|
|
|
|
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
|
|
|
|
+ f"output_blocks.{i}.{index}.conv.bias"
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ # Clear attentions as they have been attributed above.
|
|
|
|
|
+ if len(attentions) == 2:
|
|
|
|
|
+ attentions = []
|
|
|
|
|
+
|
|
|
|
|
+ if len(attentions):
|
|
|
|
|
+ paths = renew_attention_paths(attentions)
|
|
|
|
|
+ meta_path = {
|
|
|
|
|
+ "old": f"output_blocks.{i}.1",
|
|
|
|
|
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
|
|
|
|
+ }
|
|
|
|
|
+ assign_to_checkpoint(
|
|
|
|
|
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
|
|
|
|
+ for path in resnet_0_paths:
|
|
|
|
|
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
|
|
|
|
|
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint[new_path] = unet_state_dict[old_path]
|
|
|
|
|
+
|
|
|
|
|
+ if controlnet:
|
|
|
|
|
+ # conditioning embedding
|
|
|
|
|
+
|
|
|
|
|
+ orig_index = 0
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.weight"
|
|
|
|
|
+ )
|
|
|
|
|
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.bias"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ orig_index += 2
|
|
|
|
|
+
|
|
|
|
|
+ diffusers_index = 0
|
|
|
|
|
+
|
|
|
|
|
+ while diffusers_index < 6:
|
|
|
|
|
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.weight"
|
|
|
|
|
+ )
|
|
|
|
|
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.bias"
|
|
|
|
|
+ )
|
|
|
|
|
+ diffusers_index += 1
|
|
|
|
|
+ orig_index += 2
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.weight"
|
|
|
|
|
+ )
|
|
|
|
|
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
|
|
|
|
+ f"input_hint_block.{orig_index}.bias"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # down blocks
|
|
|
|
|
+ for i in range(num_input_blocks):
|
|
|
|
|
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
|
|
|
|
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
|
|
|
|
+
|
|
|
|
|
+ # mid block
|
|
|
|
|
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
|
|
|
|
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
|
|
|
|
+
|
|
|
|
|
+ return new_checkpoint
|
|
|
|
|
+
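+# Minimal sketch (not part of the original script) of how the two UNet helpers above are typically
+# combined; it mirrors the commented-out pipeline assembly at the bottom of this file.
+# `original_config` is the OmegaConf object loaded from the LDM .yaml and `checkpoint` is the raw
+# state dict loaded from the .ckpt; both names are placeholders.
+def _example_build_unet(original_config, checkpoint, checkpoint_path=None, image_size=512, extract_ema=False):
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+    )
+    unet.load_state_dict(converted_unet_checkpoint)
+    return unet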
|
|
|
|
|
+
|
|
|
|
|
+def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
|
|
|
+ # extract state dict for VAE
|
|
|
|
|
+ vae_state_dict = {}
|
|
|
|
|
+ vae_key = "first_stage_model."
|
|
|
|
|
+ keys = list(checkpoint.keys())
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if key.startswith(vae_key):
|
|
|
|
|
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint = {}
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
|
|
|
|
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
|
|
|
|
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
|
|
|
|
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
|
|
|
|
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
|
|
|
|
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
|
|
|
|
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
|
|
|
|
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
|
|
|
|
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
|
|
|
|
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
|
|
|
|
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
|
|
|
|
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
|
|
|
|
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
|
|
|
|
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieves the keys for the encoder down blocks only
|
|
|
|
|
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
|
|
|
|
+ down_blocks = {
|
|
|
|
|
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ # Retrieves the keys for the decoder up blocks only
|
|
|
|
|
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
|
|
|
|
+ up_blocks = {
|
|
|
|
|
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(num_down_blocks):
|
|
|
|
|
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
|
|
|
|
+
|
|
|
|
|
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
|
|
|
|
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
|
|
|
|
+ f"encoder.down.{i}.downsample.conv.weight"
|
|
|
|
|
+ )
|
|
|
|
|
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
|
|
|
|
+ f"encoder.down.{i}.downsample.conv.bias"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ paths = renew_vae_resnet_paths(resnets)
|
|
|
|
|
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+
|
|
|
|
|
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
|
|
|
|
+ num_mid_res_blocks = 2
|
|
|
|
|
+ for i in range(1, num_mid_res_blocks + 1):
|
|
|
|
|
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
|
|
|
|
+
|
|
|
|
|
+ paths = renew_vae_resnet_paths(resnets)
|
|
|
|
|
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+
|
|
|
|
|
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
|
|
|
|
+ paths = renew_vae_attention_paths(mid_attentions)
|
|
|
|
|
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+ conv_attn_to_linear(new_checkpoint)
|
|
|
|
|
+
|
|
|
|
|
+ for i in range(num_up_blocks):
|
|
|
|
|
+ block_id = num_up_blocks - 1 - i
|
|
|
|
|
+ resnets = [
|
|
|
|
|
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
|
|
|
|
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
|
|
|
|
+ f"decoder.up.{block_id}.upsample.conv.weight"
|
|
|
|
|
+ ]
|
|
|
|
|
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
|
|
|
|
+ f"decoder.up.{block_id}.upsample.conv.bias"
|
|
|
|
|
+ ]
|
|
|
|
|
+
|
|
|
|
|
+ paths = renew_vae_resnet_paths(resnets)
|
|
|
|
|
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+
|
|
|
|
|
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
|
|
|
|
+ num_mid_res_blocks = 2
|
|
|
|
|
+ for i in range(1, num_mid_res_blocks + 1):
|
|
|
|
|
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
|
|
|
|
+
|
|
|
|
|
+ paths = renew_vae_resnet_paths(resnets)
|
|
|
|
|
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+
|
|
|
|
|
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
|
|
|
|
+ paths = renew_vae_attention_paths(mid_attentions)
|
|
|
|
|
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
|
|
|
|
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
|
|
|
|
+ conv_attn_to_linear(new_checkpoint)
|
|
|
|
|
+ return new_checkpoint
|
|
|
|
|
+
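+# Minimal sketch (not part of the original script) of how the VAE helpers above fit together,
+# mirroring the commented-out pipeline assembly at the bottom of this file; `original_config`
+# and `checkpoint` are placeholder names for the loaded .yaml and .ckpt contents.
+def _example_build_vae(original_config, checkpoint, image_size=512):
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae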
|
|
|
|
|
+
|
|
|
|
|
+def convert_ldm_bert_checkpoint(checkpoint, config):
|
|
|
|
|
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
|
|
|
|
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
|
|
|
|
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
|
|
|
|
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
|
|
|
|
+
|
|
|
|
|
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
|
|
|
|
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
|
|
|
|
+
|
|
|
|
|
+ def _copy_linear(hf_linear, pt_linear):
|
|
|
|
|
+ hf_linear.weight = pt_linear.weight
|
|
|
|
|
+ hf_linear.bias = pt_linear.bias
|
|
|
|
|
+
|
|
|
|
|
+ def _copy_layer(hf_layer, pt_layer):
|
|
|
|
|
+ # copy layer norms
|
|
|
|
|
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
|
|
|
|
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
|
|
|
|
+
|
|
|
|
|
+ # copy attn
|
|
|
|
|
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
|
|
|
|
+
|
|
|
|
|
+ # copy MLP
|
|
|
|
|
+ pt_mlp = pt_layer[1][1]
|
|
|
|
|
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
|
|
|
|
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
|
|
|
|
+
|
|
|
|
|
+ def _copy_layers(hf_layers, pt_layers):
|
|
|
|
|
+ for i, hf_layer in enumerate(hf_layers):
|
|
|
|
|
+ if i != 0:
|
|
|
|
|
+ i += i
|
|
|
|
|
+ pt_layer = pt_layers[i : i + 2]
|
|
|
|
|
+ _copy_layer(hf_layer, pt_layer)
|
|
|
|
|
+
|
|
|
|
|
+ hf_model = LDMBertModel(config).eval()
|
|
|
|
|
+
|
|
|
|
|
+ # copy embeds
|
|
|
|
|
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
|
|
|
|
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
|
|
|
|
+
|
|
|
|
|
+ # copy layer norm
|
|
|
|
|
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
|
|
|
|
+
|
|
|
|
|
+ # copy hidden layers
|
|
|
|
|
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
|
|
|
|
+
|
|
|
|
|
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
|
|
|
|
+
|
|
|
|
|
+ return hf_model
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def convert_ldm_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    # Local mirror used during development; swap in if the Hub is unreachable:
+    # text_model = CLIPTextModel.from_pretrained("/mnt/petrelfs/guoyuwei/projects/huggingface/clip-vit-large-patch14")
+ keys = list(checkpoint.keys())
|
|
|
|
|
+
|
|
|
|
|
+ text_model_dict = {}
|
|
|
|
|
+
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if key.startswith("cond_stage_model.transformer"):
|
|
|
|
|
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
|
|
|
|
+
|
|
|
|
|
+ text_model.load_state_dict(text_model_dict)
|
|
|
|
|
+
|
|
|
|
|
+ return text_model
|
|
|
|
|
+
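+# Illustrative usage (assumes `state_dict` holds an SD v1.x checkpoint and that the matching
+# tokenizer is the standard CLIP ViT-L/14 one):
+#   text_encoder = convert_ldm_clip_checkpoint(state_dict)
+#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")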
|
|
|
|
|
+
|
|
|
|
|
+textenc_conversion_lst = [
|
|
|
|
|
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
|
|
|
|
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
|
|
|
|
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
|
|
|
|
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
|
|
|
|
+]
|
|
|
|
|
+textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
|
|
|
|
+
|
|
|
|
|
+textenc_transformer_conversion_lst = [
|
|
|
|
|
+ # (stable-diffusion, HF Diffusers)
|
|
|
|
|
+ ("resblocks.", "text_model.encoder.layers."),
|
|
|
|
|
+ ("ln_1", "layer_norm1"),
|
|
|
|
|
+ ("ln_2", "layer_norm2"),
|
|
|
|
|
+ (".c_fc.", ".fc1."),
|
|
|
|
|
+ (".c_proj.", ".fc2."),
|
|
|
|
|
+ (".attn", ".self_attn"),
|
|
|
|
|
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
|
|
|
|
|
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
|
|
|
|
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
|
|
|
|
+]
|
|
|
|
|
+protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
|
|
|
|
+textenc_pattern = re.compile("|".join(protected.keys()))
|
|
|
|
|
+
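+# For example (illustrative):
+#   textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], "resblocks.0.ln_1.weight")
+#   -> "text_model.encoder.layers.0.layer_norm1.weight"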
|
|
|
|
|
+
|
|
|
|
|
+def convert_paint_by_example_checkpoint(checkpoint):
|
|
|
|
|
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+ model = PaintByExampleImageEncoder(config)
|
|
|
|
|
+
|
|
|
|
|
+ keys = list(checkpoint.keys())
|
|
|
|
|
+
|
|
|
|
|
+ text_model_dict = {}
|
|
|
|
|
+
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if key.startswith("cond_stage_model.transformer"):
|
|
|
|
|
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
|
|
|
|
+
|
|
|
|
|
+ # load clip vision
|
|
|
|
|
+ model.model.load_state_dict(text_model_dict)
|
|
|
|
|
+
|
|
|
|
|
+ # load mapper
|
|
|
|
|
+ keys_mapper = {
|
|
|
|
|
+ k[len("cond_stage_model.mapper.res") :]: v
|
|
|
|
|
+ for k, v in checkpoint.items()
|
|
|
|
|
+ if k.startswith("cond_stage_model.mapper")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ MAPPING = {
|
|
|
|
|
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
|
|
|
|
+ "attn.c_proj": ["attn1.to_out.0"],
|
|
|
|
|
+ "ln_1": ["norm1"],
|
|
|
|
|
+ "ln_2": ["norm3"],
|
|
|
|
|
+ "mlp.c_fc": ["ff.net.0.proj"],
|
|
|
|
|
+ "mlp.c_proj": ["ff.net.2"],
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ mapped_weights = {}
|
|
|
|
|
+ for key, value in keys_mapper.items():
|
|
|
|
|
+ prefix = key[: len("blocks.i")]
|
|
|
|
|
+ suffix = key.split(prefix)[-1].split(".")[-1]
|
|
|
|
|
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
|
|
|
|
+ mapped_names = MAPPING[name]
|
|
|
|
|
+
|
|
|
|
|
+ num_splits = len(mapped_names)
|
|
|
|
|
+ for i, mapped_name in enumerate(mapped_names):
|
|
|
|
|
+ new_name = ".".join([prefix, mapped_name, suffix])
|
|
|
|
|
+ shape = value.shape[0] // num_splits
|
|
|
|
|
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
|
|
|
|
+
|
|
|
|
|
+ model.mapper.load_state_dict(mapped_weights)
|
|
|
|
|
+
|
|
|
|
|
+ # load final layer norm
|
|
|
|
|
+ model.final_layer_norm.load_state_dict(
|
|
|
|
|
+ {
|
|
|
|
|
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
|
|
|
|
|
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # load final proj
|
|
|
|
|
+ model.proj_out.load_state_dict(
|
|
|
|
|
+ {
|
|
|
|
|
+ "bias": checkpoint["proj_out.bias"],
|
|
|
|
|
+ "weight": checkpoint["proj_out.weight"],
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # load uncond vector
|
|
|
|
|
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
|
|
|
|
+ return model
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def convert_open_clip_checkpoint(checkpoint):
|
|
|
|
|
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
|
|
|
|
+
|
|
|
|
|
+ keys = list(checkpoint.keys())
|
|
|
|
|
+
|
|
|
|
|
+ text_model_dict = {}
|
|
|
|
|
+
|
|
|
|
|
+ if "cond_stage_model.model.text_projection" in checkpoint:
|
|
|
|
|
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
|
|
|
|
+ else:
|
|
|
|
|
+ d_model = 1024
|
|
|
|
|
+
|
|
|
|
|
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
|
|
|
|
+
|
|
|
|
|
+ for key in keys:
|
|
|
|
|
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
|
|
|
|
+ continue
|
|
|
|
|
+ if key in textenc_conversion_map:
|
|
|
|
|
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
|
|
|
|
+ if key.startswith("cond_stage_model.model.transformer."):
|
|
|
|
|
+ new_key = key[len("cond_stage_model.model.transformer.") :]
|
|
|
|
|
+ if new_key.endswith(".in_proj_weight"):
|
|
|
|
|
+ new_key = new_key[: -len(".in_proj_weight")]
|
|
|
|
|
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
|
|
|
|
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
|
|
|
|
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
|
|
|
|
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
|
|
|
|
+ elif new_key.endswith(".in_proj_bias"):
|
|
|
|
|
+ new_key = new_key[: -len(".in_proj_bias")]
|
|
|
|
|
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
|
|
|
|
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
|
|
|
|
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
|
|
|
|
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
|
|
|
|
+ else:
|
|
|
|
|
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
|
|
|
|
+
|
|
|
|
|
+ text_model_dict[new_key] = checkpoint[key]
|
|
|
|
|
+
|
|
|
|
|
+ text_model.load_state_dict(text_model_dict)
|
|
|
|
|
+
|
|
|
|
|
+ return text_model
|
|
|
|
|
+
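+# Note (illustrative): OpenCLIP stores the fused attention projection as `in_proj_weight` of shape
+# (3 * d_model, d_model); the three row slices above become the separate q_proj / k_proj / v_proj
+# weights expected by transformers' CLIPTextModel.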
|
|
|
|
|
+
|
|
|
|
|
+def stable_unclip_image_encoder(original_config):
|
|
|
|
|
+ """
|
|
|
|
|
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
|
|
|
|
+
|
|
|
|
|
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
|
|
|
|
+ encoders.
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ image_embedder_config = original_config.model.params.embedder_config
|
|
|
|
|
+
|
|
|
|
|
+ sd_clip_image_embedder_class = image_embedder_config.target
|
|
|
|
|
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
|
|
|
|
+
|
|
|
|
|
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
|
|
|
|
+ clip_model_name = image_embedder_config.params.model
|
|
|
|
|
+
|
|
|
|
|
+ if clip_model_name == "ViT-L/14":
|
|
|
|
|
+ feature_extractor = CLIPImageProcessor()
|
|
|
|
|
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
|
|
|
|
+
|
|
|
|
|
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
|
|
|
|
+ feature_extractor = CLIPImageProcessor()
|
|
|
|
|
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise NotImplementedError(
|
|
|
|
|
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return feature_extractor, image_encoder
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def stable_unclip_image_noising_components(
|
|
|
|
|
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
|
|
|
|
+):
|
|
|
|
|
+ """
|
|
|
|
|
+ Returns the noising components for the img2img and txt2img unclip pipelines.
|
|
|
|
|
+
|
|
|
|
|
+ Converts the stability noise augmentor into
|
|
|
|
|
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
|
|
|
|
+ 2. a `DDPMScheduler` for holding the noise schedule
|
|
|
|
|
+
|
|
|
|
|
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
|
|
|
|
+ """
|
|
|
|
|
+ noise_aug_config = original_config.model.params.noise_aug_config
|
|
|
|
|
+ noise_aug_class = noise_aug_config.target
|
|
|
|
|
+ noise_aug_class = noise_aug_class.split(".")[-1]
|
|
|
|
|
+
|
|
|
|
|
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
|
|
|
|
+ noise_aug_config = noise_aug_config.params
|
|
|
|
|
+ embedding_dim = noise_aug_config.timestep_dim
|
|
|
|
|
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
|
|
|
|
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
|
|
|
|
+
|
|
|
|
|
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
|
|
|
|
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
|
|
|
|
+
|
|
|
|
|
+ if "clip_stats_path" in noise_aug_config:
|
|
|
|
|
+ if clip_stats_path is None:
|
|
|
|
|
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
|
|
|
|
+
|
|
|
|
|
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
|
|
|
|
+ clip_mean = clip_mean[None, :]
|
|
|
|
|
+ clip_std = clip_std[None, :]
|
|
|
|
|
+
|
|
|
|
|
+ clip_stats_state_dict = {
|
|
|
|
|
+ "mean": clip_mean,
|
|
|
|
|
+ "std": clip_std,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ image_normalizer.load_state_dict(clip_stats_state_dict)
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
|
|
|
|
+
|
|
|
|
|
+ return image_normalizer, image_noising_scheduler
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def convert_controlnet_checkpoint(
|
|
|
|
|
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
|
|
|
|
+):
|
|
|
|
|
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
|
|
|
|
+ ctrlnet_config["upcast_attention"] = upcast_attention
|
|
|
|
|
+
|
|
|
|
|
+ ctrlnet_config.pop("sample_size")
|
|
|
|
|
+
|
|
|
|
|
+ controlnet_model = ControlNetModel(**ctrlnet_config)
|
|
|
|
|
+
|
|
|
|
|
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
|
|
|
|
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
|
|
|
|
+
|
|
|
|
|
+ return controlnet_model
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# def download_from_original_stable_diffusion_ckpt(
|
|
|
|
|
+# checkpoint_path: str,
|
|
|
|
|
+# original_config_file: str = None,
|
|
|
|
|
+# image_size: int = 512,
|
|
|
|
|
+# prediction_type: str = None,
|
|
|
|
|
+# model_type: str = None,
|
|
|
|
|
+# extract_ema: bool = False,
|
|
|
|
|
+# scheduler_type: str = "pndm",
|
|
|
|
|
+# num_in_channels: Optional[int] = None,
|
|
|
|
|
+# upcast_attention: Optional[bool] = None,
|
|
|
|
|
+# device: str = None,
|
|
|
|
|
+# from_safetensors: bool = False,
|
|
|
|
|
+# stable_unclip: Optional[str] = None,
|
|
|
|
|
+# stable_unclip_prior: Optional[str] = None,
|
|
|
|
|
+# clip_stats_path: Optional[str] = None,
|
|
|
|
|
+# controlnet: Optional[bool] = None,
|
|
|
|
|
+# load_safety_checker: bool = True,
|
|
|
|
|
+# pipeline_class: DiffusionPipeline = None,
|
|
|
|
|
+# ) -> DiffusionPipeline:
|
|
|
|
|
+# """
|
|
|
|
|
+# Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
|
|
|
|
+# config file.
|
|
|
|
|
+
|
|
|
|
|
+# Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
|
|
|
|
|
+# global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
|
|
|
|
|
+# recommended that you override the default values and/or supply an `original_config_file` wherever possible.
|
|
|
|
|
+
|
|
|
|
|
+# Args:
|
|
|
|
|
+# checkpoint_path (`str`): Path to `.ckpt` file.
|
|
|
|
|
+# original_config_file (`str`):
|
|
|
|
|
+# Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
|
|
|
|
|
+# inferred by looking for a key that only exists in SD2.0 models.
|
|
|
|
|
+# image_size (`int`, *optional*, defaults to 512):
|
|
|
|
|
+# The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
|
|
|
|
+# Base. Use 768 for Stable Diffusion v2.
|
|
|
|
|
+# prediction_type (`str`, *optional*):
|
|
|
|
|
+# The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
|
|
|
|
|
+# Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
|
|
|
|
|
+# num_in_channels (`int`, *optional*, defaults to None):
|
|
|
|
|
+# The number of input channels. If `None`, it will be automatically inferred.
|
|
|
|
|
+# scheduler_type (`str`, *optional*, defaults to 'pndm'):
|
|
|
|
|
+# Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
|
|
|
|
+# "ddim"]`.
|
|
|
|
|
+# model_type (`str`, *optional*, defaults to `None`):
|
|
|
|
|
+# The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
|
|
|
|
|
+# "FrozenCLIPEmbedder", "PaintByExample"]`.
|
|
|
|
|
+# is_img2img (`bool`, *optional*, defaults to `False`):
|
|
|
|
|
+# Whether the model should be loaded as an img2img pipeline.
|
|
|
|
|
+# extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
|
|
|
|
+# checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
|
|
|
|
+# `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
|
|
|
|
+# inference. Non-EMA weights are usually better to continue fine-tuning.
|
|
|
|
|
+# upcast_attention (`bool`, *optional*, defaults to `None`):
|
|
|
|
|
+# Whether the attention computation should always be upcasted. This is necessary when running stable
|
|
|
|
|
+# diffusion 2.1.
|
|
|
|
|
+# device (`str`, *optional*, defaults to `None`):
|
|
|
|
|
+# The device to use. Pass `None` to determine automatically.
|
|
|
|
|
+# from_safetensors (`str`, *optional*, defaults to `False`):
|
|
|
|
|
+# If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
|
|
|
|
|
+# load_safety_checker (`bool`, *optional*, defaults to `True`):
|
|
|
|
|
+# Whether to load the safety checker or not. Defaults to `True`.
|
|
|
|
|
+# pipeline_class (`str`, *optional*, defaults to `None`):
|
|
|
|
|
+# The pipeline class to use. Pass `None` to determine automatically.
|
|
|
|
|
+# return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
|
|
|
|
+# """
|
|
|
|
|
+
|
|
|
|
|
+# # import pipelines here to avoid circular import error when using from_ckpt method
|
|
|
|
|
+# from diffusers import (
|
|
|
|
|
+# LDMTextToImagePipeline,
|
|
|
|
|
+# PaintByExamplePipeline,
|
|
|
|
|
+# StableDiffusionControlNetPipeline,
|
|
|
|
|
+# StableDiffusionPipeline,
|
|
|
|
|
+# StableUnCLIPImg2ImgPipeline,
|
|
|
|
|
+# StableUnCLIPPipeline,
|
|
|
|
|
+# )
|
|
|
|
|
+
|
|
|
|
|
+# if pipeline_class is None:
|
|
|
|
|
+# pipeline_class = StableDiffusionPipeline
|
|
|
|
|
+
|
|
|
|
|
+# if prediction_type == "v-prediction":
|
|
|
|
|
+# prediction_type = "v_prediction"
|
|
|
|
|
+
|
|
|
|
|
+# if not is_omegaconf_available():
|
|
|
|
|
+# raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
|
|
|
|
+
|
|
|
|
|
+# from omegaconf import OmegaConf
|
|
|
|
|
+
|
|
|
|
|
+# if from_safetensors:
|
|
|
|
|
+# if not is_safetensors_available():
|
|
|
|
|
+# raise ValueError(BACKENDS_MAPPING["safetensors"][1])
|
|
|
|
|
+
|
|
|
|
|
+# from safetensors import safe_open
|
|
|
|
|
+
|
|
|
|
|
+# checkpoint = {}
|
|
|
|
|
+# with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
|
|
|
|
|
+# for key in f.keys():
|
|
|
|
|
+# checkpoint[key] = f.get_tensor(key)
|
|
|
|
|
+# else:
|
|
|
|
|
+# if device is None:
|
|
|
|
|
+# device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
+# checkpoint = torch.load(checkpoint_path, map_location=device)
|
|
|
|
|
+# else:
|
|
|
|
|
+# checkpoint = torch.load(checkpoint_path, map_location=device)
|
|
|
|
|
+
|
|
|
|
|
+# # Sometimes models don't have the global_step item
|
|
|
|
|
+# if "global_step" in checkpoint:
|
|
|
|
|
+# global_step = checkpoint["global_step"]
|
|
|
|
|
+# else:
|
|
|
|
|
+# print("global_step key not found in model")
|
|
|
|
|
+# global_step = None
|
|
|
|
|
+
|
|
|
|
|
+# # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
|
|
|
|
+# # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
|
|
|
|
+# while "state_dict" in checkpoint:
|
|
|
|
|
+# checkpoint = checkpoint["state_dict"]
|
|
|
|
|
+
|
|
|
|
|
+# if original_config_file is None:
|
|
|
|
|
+# key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
|
|
|
|
+
|
|
|
|
|
+# # model_type = "v1"
|
|
|
|
|
+# config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
|
|
|
|
+
|
|
|
|
|
+# if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
|
|
|
|
+# # model_type = "v2"
|
|
|
|
|
+# config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
|
|
|
|
+
|
|
|
|
|
+# if global_step == 110000:
|
|
|
|
|
+# # v2.1 needs to upcast attention
|
|
|
|
|
+# upcast_attention = True
|
|
|
|
|
+
|
|
|
|
|
+# original_config_file = BytesIO(requests.get(config_url).content)
|
|
|
|
|
+
|
|
|
|
|
+# original_config = OmegaConf.load(original_config_file)
|
|
|
|
|
+
|
|
|
|
|
+# if num_in_channels is not None:
|
|
|
|
|
+# original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
|
|
|
|
+
|
|
|
|
|
+# if (
|
|
|
|
|
+# "parameterization" in original_config["model"]["params"]
|
|
|
|
|
+# and original_config["model"]["params"]["parameterization"] == "v"
|
|
|
|
|
+# ):
|
|
|
|
|
+# if prediction_type is None:
|
|
|
|
|
+# # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
|
|
|
|
|
+# # as it relies on a brittle global step parameter here
|
|
|
|
|
+# prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
|
|
|
|
|
+# if image_size is None:
|
|
|
|
|
+# # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
|
|
|
|
+# # as it relies on a brittle global step parameter here
|
|
|
|
|
+# image_size = 512 if global_step == 875000 else 768
|
|
|
|
|
+# else:
|
|
|
|
|
+# if prediction_type is None:
|
|
|
|
|
+# prediction_type = "epsilon"
|
|
|
|
|
+# if image_size is None:
|
|
|
|
|
+# image_size = 512
|
|
|
|
|
+
|
|
|
|
|
+# if controlnet is None:
|
|
|
|
|
+# controlnet = "control_stage_config" in original_config.model.params
|
|
|
|
|
+
|
|
|
|
|
+# if controlnet:
|
|
|
|
|
+# controlnet_model = convert_controlnet_checkpoint(
|
|
|
|
|
+# checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
|
|
|
|
+# )
|
|
|
|
|
+
|
|
|
|
|
+# num_train_timesteps = original_config.model.params.timesteps
+# beta_start = original_config.model.params.linear_start
+# beta_end = original_config.model.params.linear_end
+
+# scheduler = DDIMScheduler(
+#     beta_end=beta_end,
+#     beta_schedule="scaled_linear",
+#     beta_start=beta_start,
+#     num_train_timesteps=num_train_timesteps,
+#     steps_offset=1,
+#     clip_sample=False,
+#     set_alpha_to_one=False,
+#     prediction_type=prediction_type,
+# )
+# # make sure scheduler works correctly with DDIM
+# scheduler.register_to_config(clip_sample=False)
+
+# if scheduler_type == "pndm":
+#     config = dict(scheduler.config)
+#     config["skip_prk_steps"] = True
+#     scheduler = PNDMScheduler.from_config(config)
+# elif scheduler_type == "lms":
+#     scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+# elif scheduler_type == "heun":
+#     scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
+# elif scheduler_type == "euler":
+#     scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
+# elif scheduler_type == "euler-ancestral":
+#     scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
+# elif scheduler_type == "dpm":
+#     scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
+# elif scheduler_type == "ddim":
+#     scheduler = scheduler
+# else:
+#     raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+
+# # Convert the UNet2DConditionModel model.
+# unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+# unet_config["upcast_attention"] = upcast_attention
+# unet = UNet2DConditionModel(**unet_config)
+
+# converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+#     checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+# )
+
+# unet.load_state_dict(converted_unet_checkpoint)
+
+# # Convert the VAE model.
+# vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+# converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+# vae = AutoencoderKL(**vae_config)
+# vae.load_state_dict(converted_vae_checkpoint)
+
+# # Convert the text model.
+# if model_type is None:
+#     model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+#     logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
+
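+# # Pipeline assembly keyed on the text encoder: FrozenOpenCLIPEmbedder -> SD 2.x (optionally a
+# # Stable unCLIP variant), PaintByExample, FrozenCLIPEmbedder -> SD 1.x, otherwise an LDM
+# # text-to-image model with an LDMBert encoder.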
+# if model_type == "FrozenOpenCLIPEmbedder":
|
|
|
|
|
+# text_model = convert_open_clip_checkpoint(checkpoint)
|
|
|
|
|
+# tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
|
|
|
|
+
|
|
|
|
|
+# if stable_unclip is None:
|
|
|
|
|
+# if controlnet:
|
|
|
|
|
+# pipe = StableDiffusionControlNetPipeline(
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# controlnet=controlnet_model,
|
|
|
|
|
+# safety_checker=None,
|
|
|
|
|
+# feature_extractor=None,
|
|
|
|
|
+# requires_safety_checker=False,
|
|
|
|
|
+# )
|
|
|
|
|
+# else:
|
|
|
|
|
+# pipe = pipeline_class(
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# safety_checker=None,
|
|
|
|
|
+# feature_extractor=None,
|
|
|
|
|
+# requires_safety_checker=False,
|
|
|
|
|
+# )
|
|
|
|
|
+# else:
|
|
|
|
|
+# image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
|
|
|
|
|
+# original_config, clip_stats_path=clip_stats_path, device=device
|
|
|
|
|
+# )
|
|
|
|
|
+
|
|
|
|
|
+# if stable_unclip == "img2img":
|
|
|
|
|
+# feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
|
|
|
|
|
+
|
|
|
|
|
+# pipe = StableUnCLIPImg2ImgPipeline(
|
|
|
|
|
+# # image encoding components
|
|
|
|
|
+# feature_extractor=feature_extractor,
|
|
|
|
|
+# image_encoder=image_encoder,
|
|
|
|
|
+# # image noising components
|
|
|
|
|
+# image_normalizer=image_normalizer,
|
|
|
|
|
+# image_noising_scheduler=image_noising_scheduler,
|
|
|
|
|
+# # regular denoising components
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# # vae
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# )
|
|
|
|
|
+# elif stable_unclip == "txt2img":
|
|
|
|
|
+# if stable_unclip_prior is None or stable_unclip_prior == "karlo":
|
|
|
|
|
+# karlo_model = "kakaobrain/karlo-v1-alpha"
|
|
|
|
|
+# prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
|
|
|
|
|
+
|
|
|
|
|
+# prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+# prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+
|
|
|
|
|
+# prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
|
|
|
|
|
+# prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
|
|
|
|
+# else:
|
|
|
|
|
+# raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
|
|
|
|
|
+
|
|
|
|
|
+# pipe = StableUnCLIPPipeline(
|
|
|
|
|
+# # prior components
|
|
|
|
|
+# prior_tokenizer=prior_tokenizer,
|
|
|
|
|
+# prior_text_encoder=prior_text_model,
|
|
|
|
|
+# prior=prior,
|
|
|
|
|
+# prior_scheduler=prior_scheduler,
|
|
|
|
|
+# # image noising components
|
|
|
|
|
+# image_normalizer=image_normalizer,
|
|
|
|
|
+# image_noising_scheduler=image_noising_scheduler,
|
|
|
|
|
+# # regular denoising components
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# # vae
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# )
|
|
|
|
|
+# else:
|
|
|
|
|
+# raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
|
|
|
|
|
+# elif model_type == "PaintByExample":
|
|
|
|
|
+# vision_model = convert_paint_by_example_checkpoint(checkpoint)
|
|
|
|
|
+# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+# feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
|
|
|
|
+# pipe = PaintByExamplePipeline(
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# image_encoder=vision_model,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# safety_checker=None,
|
|
|
|
|
+# feature_extractor=feature_extractor,
|
|
|
|
|
+# )
|
|
|
|
|
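+# # SD 1.x branch below: the CLIP tokenizer is loaded from a local copy of
+# # openai/clip-vit-large-patch14 and the safety checker is disabled unconditionally (`if False:`).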
+# elif model_type == "FrozenCLIPEmbedder":
|
|
|
|
|
+# text_model = convert_ldm_clip_checkpoint(checkpoint)
|
|
|
|
|
+# # tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
|
|
|
|
+# tokenizer = CLIPTokenizer.from_pretrained("/mnt/petrelfs/guoyuwei/projects/huggingface/clip-vit-large-patch14")
|
|
|
|
|
+
|
|
|
|
|
+# # if load_safety_checker:
|
|
|
|
|
+# if False:
|
|
|
|
|
+# safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
|
|
|
|
+# feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
|
|
|
|
+# else:
|
|
|
|
|
+# safety_checker = None
|
|
|
|
|
+# feature_extractor = None
|
|
|
|
|
+
|
|
|
|
|
+# if controlnet:
|
|
|
|
|
+# pipe = StableDiffusionControlNetPipeline(
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# controlnet=controlnet_model,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# safety_checker=safety_checker,
|
|
|
|
|
+# feature_extractor=feature_extractor,
|
|
|
|
|
+# )
|
|
|
|
|
+# else:
|
|
|
|
|
+# pipe = pipeline_class(
|
|
|
|
|
+# vae=vae,
|
|
|
|
|
+# text_encoder=text_model,
|
|
|
|
|
+# tokenizer=tokenizer,
|
|
|
|
|
+# unet=unet,
|
|
|
|
|
+# scheduler=scheduler,
|
|
|
|
|
+# safety_checker=safety_checker,
|
|
|
|
|
+# feature_extractor=feature_extractor,
|
|
|
|
|
+# )
|
|
|
|
|
+# else:
+#     text_config = create_ldm_bert_config(original_config)
+#     text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
+#     tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+#     pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+# return pipe
+
+
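+# # Standalone helper that loads an original ControlNet checkpoint (optionally safetensors),
+# # validates that `control_stage_config` is present, and returns the converted ControlNet model.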
+# def download_controlnet_from_original_ckpt(
+#     checkpoint_path: str,
+#     original_config_file: str,
+#     image_size: int = 512,
+#     extract_ema: bool = False,
+#     num_in_channels: Optional[int] = None,
+#     upcast_attention: Optional[bool] = None,
+#     device: str = None,
+#     from_safetensors: bool = False,
+# ) -> DiffusionPipeline:
+#     if not is_omegaconf_available():
+#         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+#     from omegaconf import OmegaConf
+
+#     if from_safetensors:
+#         if not is_safetensors_available():
+#             raise ValueError(BACKENDS_MAPPING["safetensors"][1])
+
+#         from safetensors import safe_open
+
+#         checkpoint = {}
+#         with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+#             for key in f.keys():
+#                 checkpoint[key] = f.get_tensor(key)
+#     else:
+#         if device is None:
+#             device = "cuda" if torch.cuda.is_available() else "cpu"
+#             checkpoint = torch.load(checkpoint_path, map_location=device)
+#         else:
+#             checkpoint = torch.load(checkpoint_path, map_location=device)
+
+#     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+#     # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+#     while "state_dict" in checkpoint:
+#         checkpoint = checkpoint["state_dict"]
+
+#     original_config = OmegaConf.load(original_config_file)
+
+#     if num_in_channels is not None:
+#         original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+#     if "control_stage_config" not in original_config.model.params:
+#         raise ValueError("`control_stage_config` not present in original config")
+
+#     controlnet_model = convert_controlnet_checkpoint(
+#         checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+#     )
+
+#     return controlnet_model
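+# # Usage sketch for the helper above (hypothetical file names), assuming it were re-enabled:
+# #
+# #     controlnet = download_controlnet_from_original_ckpt(
+# #         checkpoint_path="control_sd15_canny.pth",
+# #         original_config_file="cldm_v15.yaml",
+# #     )
+# #     controlnet.save_pretrained("converted-controlnet")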