- # coding=utf-8
- # Copyright 2023 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ Conversion script for the Stable Diffusion checkpoints."""
- import re
- from io import BytesIO
- from typing import Optional
- import requests
- import torch
- from transformers import (
- AutoFeatureExtractor,
- BertTokenizerFast,
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
- )
- from diffusers.models import (
- AutoencoderKL,
- ControlNetModel,
- PriorTransformer,
- UNet2DConditionModel,
- )
- from diffusers.schedulers import (
- DDIMScheduler,
- DDPMScheduler,
- DPMSolverMultistepScheduler,
- EulerAncestralDiscreteScheduler,
- EulerDiscreteScheduler,
- HeunDiscreteScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- UnCLIPScheduler,
- )
- from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
- from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
- from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
- from diffusers.utils.import_utils import BACKENDS_MAPPING
- def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes dot-separated path segments. A positive value shaves the first segments; a negative value shaves the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
- def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
- mapping.append({"old": old_item, "new": new_item})
- return mapping
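- # For example (illustrative):
- #     renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
- #     # -> [{"old": "input_blocks.1.0.in_layers.0.weight",
- #     #      "new": "input_blocks.1.0.norm1.weight"}]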
- def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
- mapping.append({"old": old_item, "new": new_item})
- return mapping
- def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
- mapping.append({"old": old_item, "new": new_item})
- return mapping
- def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
- mapping.append({"old": old_item, "new": new_item})
- return mapping
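- # For example (illustrative):
- #     renew_vae_attention_paths(["encoder.mid.attn_1.q.weight"])
- #     # -> [{"old": "encoder.mid.attn_1.q.weight",
- #     #      "new": "encoder.mid.attn_1.query.weight"}]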
- def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
- ):
- """
- Performs the final conversion step: takes locally converted weights, applies the global renaming to them, splits
- attention layers where requested, and honors any additional replacements.
- Assigns the resulting weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
- for path in paths:
- new_path = path["new"]
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
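- # Minimal usage sketch with hypothetical tensors (the `middle_block` renames and
- # the attention split are skipped here because no such paths are involved):
- #     old_sd = {"input_blocks.1.0.in_layers.0.weight": torch.ones(320)}
- #     new_sd = {}
- #     paths = renew_resnet_paths(list(old_sd))
- #     meta_path = {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}
- #     assign_to_checkpoint(paths, new_sd, old_sd, additional_replacements=[meta_path])
- #     # new_sd == {"down_blocks.0.resnets.0.norm1.weight": torch.ones(320)}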
- def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
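- # Illustrative effect on a hypothetical state dict entry:
- #     sd = {"mid_block.attentions.0.query.weight": torch.ones(512, 512, 1, 1)}
- #     conv_attn_to_linear(sd)
- #     # sd["mid_block.attentions.0.query.weight"].shape == (512, 512)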
- def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
- """
- Creates a UNet config for diffusers based on the config of the LDM model.
- """
- if controlnet:
- unet_params = original_config.model.params.control_stage_config.params
- else:
- unet_params = original_config.model.params.unet_config.params
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
- block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
- vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
- head_dim = unet_params.num_heads if "num_heads" in unet_params else None
- use_linear_projection = (
- unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
- )
- if use_linear_projection:
- # stable diffusion 2-base-512 and 2-768
- if head_dim is None:
- head_dim = [5, 10, 20, 20]
- class_embed_type = None
- projection_class_embeddings_input_dim = None
- if "num_classes" in unet_params:
- if unet_params.num_classes == "sequential":
- class_embed_type = "projection"
- assert "adm_in_channels" in unet_params
- projection_class_embeddings_input_dim = unet_params.adm_in_channels
- else:
- raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
- config = {
- "sample_size": image_size // vae_scale_factor,
- "in_channels": unet_params.in_channels,
- "down_block_types": tuple(down_block_types),
- "block_out_channels": tuple(block_out_channels),
- "layers_per_block": unet_params.num_res_blocks,
- "cross_attention_dim": unet_params.context_dim,
- "attention_head_dim": head_dim,
- "use_linear_projection": use_linear_projection,
- "class_embed_type": class_embed_type,
- "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
- }
- if not controlnet:
- config["out_channels"] = unet_params.out_channels
- config["up_block_types"] = tuple(up_block_types)
- return config
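- # For the reference SD v1 config (model_channels=320, channel_mult=[1, 2, 4, 4],
- # attention_resolutions=[4, 2, 1], vae ch_mult=[1, 2, 4, 4], image_size=512) this
- # yields sample_size=64, block_out_channels=(320, 640, 1280, 1280),
- # down_block_types=("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",) and
- # up_block_types=("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3.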
- def create_vae_diffusers_config(original_config, image_size: int):
- """
- Creates a VAE config for diffusers based on the config of the LDM model.
- """
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
- _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
- config = {
- "sample_size": image_size,
- "in_channels": vae_params.in_channels,
- "out_channels": vae_params.out_ch,
- "down_block_types": tuple(down_block_types),
- "up_block_types": tuple(up_block_types),
- "block_out_channels": tuple(block_out_channels),
- "latent_channels": vae_params.z_channels,
- "layers_per_block": vae_params.num_res_blocks,
- }
- return config
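- # For the reference SD v1 VAE config (ch=128, ch_mult=[1, 2, 4, 4], z_channels=4,
- # image_size=512) this yields block_out_channels=(128, 256, 512, 512),
- # latent_channels=4 and sample_size=512.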
- def create_diffusers_schedular(original_config):
- scheduler = DDIMScheduler(
- num_train_timesteps=original_config.model.params.timesteps,
- beta_start=original_config.model.params.linear_start,
- beta_end=original_config.model.params.linear_end,
- beta_schedule="scaled_linear",
- )
- return scheduler
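- # With the reference SD v1 values (timesteps=1000, linear_start=0.00085,
- # linear_end=0.0120) this builds the usual 1000-step scaled-linear DDIM schedule.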
- def create_ldm_bert_config(original_config):
- bert_params = original_config.model.params.cond_stage_config.params
- config = LDMBertConfig(
- d_model=bert_params.n_embed,
- encoder_layers=bert_params.n_layer,
- encoder_ffn_dim=bert_params.n_embed * 4,
- )
- return config
- def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
- # extract state_dict for UNet
- unet_state_dict = {}
- keys = list(checkpoint.keys())
- if controlnet:
- unet_key = "control_model."
- else:
- unet_key = "model.diffusion_model."
- # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
- if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
- print(f"Checkpoint {path} has both EMA and non-EMA weights.")
- print(
- "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
- " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
- )
- for key in keys:
- if key.startswith("model.diffusion_model"):
- flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
- else:
- if sum(k.startswith("model_ema") for k in keys) > 100:
- print(
- "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
- " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
- )
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
- new_checkpoint = {}
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
- if config["class_embed_type"] is None:
- # No parameters to port
- ...
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
- new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
- new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
- new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
- new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
- else:
- raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
- if not controlnet:
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
- if ["conv.bias", "conv.weight"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
- new_checkpoint[new_path] = unet_state_dict[old_path]
- if controlnet:
- # conditioning embedding
- orig_index = 0
- new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.weight"
- )
- new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.bias"
- )
- orig_index += 2
- diffusers_index = 0
- while diffusers_index < 6:
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.weight"
- )
- new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.bias"
- )
- diffusers_index += 1
- orig_index += 2
- new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.weight"
- )
- new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
- f"input_hint_block.{orig_index}.bias"
- )
- # down blocks
- for i in range(num_input_blocks):
- new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
- new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
- # mid block
- new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
- new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
- return new_checkpoint
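- # Minimal usage sketch (hypothetical local files; `OmegaConf` is the loader the
- # original LDM configs are written for):
- #     original_config = OmegaConf.load("v1-inference.yaml")
- #     state_dict = torch.load("sd-v1-4.ckpt", map_location="cpu")["state_dict"]
- #     unet_config = create_unet_diffusers_config(original_config, image_size=512)
- #     converted_unet = convert_ldm_unet_checkpoint(state_dict, unet_config)
- #     unet = UNet2DConditionModel(**unet_config)
- #     unet.load_state_dict(converted_unet)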
- def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
- new_checkpoint = {}
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
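- # Minimal usage sketch, mirroring the UNet flow above (hypothetical inputs):
- #     vae_config = create_vae_diffusers_config(original_config, image_size=512)
- #     converted_vae = convert_ldm_vae_checkpoint(state_dict, vae_config)
- #     vae = AutoencoderKL(**vae_config)
- #     vae.load_state_dict(converted_vae)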
- def convert_ldm_bert_checkpoint(checkpoint, config):
- def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
- hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
- hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
- hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
- hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
- hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
- def _copy_linear(hf_linear, pt_linear):
- hf_linear.weight = pt_linear.weight
- hf_linear.bias = pt_linear.bias
- def _copy_layer(hf_layer, pt_layer):
- # copy layer norms
- _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
- _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
- # copy attn
- _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
- # copy MLP
- pt_mlp = pt_layer[1][1]
- _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
- _copy_linear(hf_layer.fc2, pt_mlp.net[2])
- def _copy_layers(hf_layers, pt_layers):
- # each HF encoder layer maps onto a pair of consecutive PT layers
- for i, hf_layer in enumerate(hf_layers):
- pt_layer = pt_layers[2 * i : 2 * i + 2]
- _copy_layer(hf_layer, pt_layer)
- hf_model = LDMBertModel(config).eval()
- # copy embeds
- hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
- hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
- # copy layer norm
- _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
- # copy hidden layers
- _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
- _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
- return hf_model
- def convert_ldm_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
- keys = list(checkpoint.keys())
- text_model_dict = {}
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
- text_model.load_state_dict(text_model_dict)
- return text_model
- textenc_conversion_lst = [
- ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
- ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
- ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
- ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
- ]
- textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
- textenc_transformer_conversion_lst = [
- # (stable-diffusion, HF Diffusers)
- ("resblocks.", "text_model.encoder.layers."),
- ("ln_1", "layer_norm1"),
- ("ln_2", "layer_norm2"),
- (".c_fc.", ".fc1."),
- (".c_proj.", ".fc2."),
- (".attn", ".self_attn"),
- ("ln_final.", "transformer.text_model.final_layer_norm."),
- ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
- ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
- ]
- protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
- textenc_pattern = re.compile("|".join(protected.keys()))
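- # Illustrative example of the one-pass rewrite this pattern enables:
- #     textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], "resblocks.0.ln_1.weight")
- #     # -> "text_model.encoder.layers.0.layer_norm1.weight"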
- def convert_paint_by_example_checkpoint(checkpoint):
- config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
- model = PaintByExampleImageEncoder(config)
- keys = list(checkpoint.keys())
- text_model_dict = {}
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
- # load clip vision
- model.model.load_state_dict(text_model_dict)
- # load mapper
- keys_mapper = {
- k[len("cond_stage_model.mapper.res") :]: v
- for k, v in checkpoint.items()
- if k.startswith("cond_stage_model.mapper")
- }
- MAPPING = {
- "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
- "attn.c_proj": ["attn1.to_out.0"],
- "ln_1": ["norm1"],
- "ln_2": ["norm3"],
- "mlp.c_fc": ["ff.net.0.proj"],
- "mlp.c_proj": ["ff.net.2"],
- }
- mapped_weights = {}
- for key, value in keys_mapper.items():
- # keys look like "blocks.<i>.<module>.<param>", e.g. "blocks.0.attn.c_qkv.weight"
- prefix = key[: len("blocks.i")]  # "blocks.<i>" (assumes a single-digit block index)
- suffix = key.split(prefix)[-1].split(".")[-1]  # "weight" or "bias"
- name = key.split(prefix)[-1].split(suffix)[0][1:-1]  # e.g. "attn.c_qkv"
- mapped_names = MAPPING[name]
- num_splits = len(mapped_names)
- for i, mapped_name in enumerate(mapped_names):
- new_name = ".".join([prefix, mapped_name, suffix])
- shape = value.shape[0] // num_splits
- mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
- model.mapper.load_state_dict(mapped_weights)
- # load final layer norm
- model.final_layer_norm.load_state_dict(
- {
- "bias": checkpoint["cond_stage_model.final_ln.bias"],
- "weight": checkpoint["cond_stage_model.final_ln.weight"],
- }
- )
- # load final proj
- model.proj_out.load_state_dict(
- {
- "bias": checkpoint["proj_out.bias"],
- "weight": checkpoint["proj_out.weight"],
- }
- )
- # load uncond vector
- model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
- return model
- def convert_open_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
- keys = list(checkpoint.keys())
- text_model_dict = {}
- if "cond_stage_model.model.text_projection" in checkpoint:
- d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
- else:
- d_model = 1024
- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
- for key in keys:
- if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
- continue
- if key in textenc_conversion_map:
- text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
- if key.startswith("cond_stage_model.model.transformer."):
- new_key = key[len("cond_stage_model.model.transformer.") :]
- if new_key.endswith(".in_proj_weight"):
- new_key = new_key[: -len(".in_proj_weight")]
- new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
- text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
- text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
- text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
- elif new_key.endswith(".in_proj_bias"):
- new_key = new_key[: -len(".in_proj_bias")]
- new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
- text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
- text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
- text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
- else:
- new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
- text_model_dict[new_key] = checkpoint[key]
- text_model.load_state_dict(text_model_dict)
- return text_model
- def stable_unclip_image_encoder(original_config):
- """
- Returns the image processor and CLIP image encoder for the img2img unCLIP pipeline.
- We currently know of two types of stable unCLIP models, which use either the CLIP or the OpenCLIP image
- encoder.
- """
- image_embedder_config = original_config.model.params.embedder_config
- sd_clip_image_embedder_class = image_embedder_config.target
- sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
- if sd_clip_image_embedder_class == "ClipImageEmbedder":
- clip_model_name = image_embedder_config.params.model
- if clip_model_name == "ViT-L/14":
- feature_extractor = CLIPImageProcessor()
- image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
- else:
- raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
- elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
- feature_extractor = CLIPImageProcessor()
- image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
- else:
- raise NotImplementedError(
- f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
- )
- return feature_extractor, image_encoder
- def stable_unclip_image_noising_components(
- original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
- ):
- """
- Returns the noising components for the img2img and txt2img unclip pipelines.
- Converts the stability noise augmentor into
- 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
- 2. a `DDPMScheduler` for holding the noise schedule
- If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
- """
- noise_aug_config = original_config.model.params.noise_aug_config
- noise_aug_class = noise_aug_config.target
- noise_aug_class = noise_aug_class.split(".")[-1]
- if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
- noise_aug_config = noise_aug_config.params
- embedding_dim = noise_aug_config.timestep_dim
- max_noise_level = noise_aug_config.noise_schedule_config.timesteps
- beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
- image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
- image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
- if "clip_stats_path" in noise_aug_config:
- if clip_stats_path is None:
- raise ValueError("This stable unclip config requires a `clip_stats_path`")
- clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
- clip_mean = clip_mean[None, :]
- clip_std = clip_std[None, :]
- clip_stats_state_dict = {
- "mean": clip_mean,
- "std": clip_std,
- }
- image_normalizer.load_state_dict(clip_stats_state_dict)
- else:
- raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
- return image_normalizer, image_noising_scheduler
- def convert_controlnet_checkpoint(
- checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
- ):
- ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
- ctrlnet_config["upcast_attention"] = upcast_attention
- ctrlnet_config.pop("sample_size")
- controlnet_model = ControlNetModel(**ctrlnet_config)
- converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
- checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
- )
- controlnet_model.load_state_dict(converted_ctrl_checkpoint)
- return controlnet_model
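- # Minimal usage sketch (hypothetical local files; the config must define
- # `model.params.control_stage_config`):
- #     original_config = OmegaConf.load("cldm_v15.yaml")
- #     state_dict = torch.load("control_sd15_canny.pth", map_location="cpu")
- #     controlnet = convert_controlnet_checkpoint(
- #         state_dict, original_config, "control_sd15_canny.pth",
- #         image_size=512, upcast_attention=False, extract_ema=False,
- #     )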