convert_from_ckpt.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""

import re
from io import BytesIO
from typing import Optional

import requests
import torch
from transformers import (
    AutoFeatureExtractor,
    BertTokenizerFast,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from diffusers.models import (
    AutoencoderKL,
    ControlNetModel,
    PriorTransformer,
    UNet2DConditionModel,
)
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UnCLIPScheduler,
)
from diffusers.utils.import_utils import BACKENDS_MAPPING

# The classes below are used by the conversion helpers further down but were not
# imported above; the module paths assume a 2023-era diffusers package layout.
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")
        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")
        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]
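

# Example (a sketch): how the renaming helpers above compose. `new_checkpoint`,
# `unet_state_dict` and `unet_config` stand for the dicts built inside the
# converter functions below; they are placeholders here.
#
#   paths = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
#   # -> [{"old": "input_blocks.1.0.in_layers.0.weight",
#   #      "new": "input_blocks.1.0.norm1.weight"}]
#   assign_to_checkpoint(
#       paths,
#       new_checkpoint,
#       unet_state_dict,
#       additional_replacements=[{"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}],
#       config=unet_config,
#   )
#   # new_checkpoint now holds the weight under "down_blocks.0.resnets.0.norm1.weight"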


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        unet_params = original_config.model.params.unet_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config


def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config


def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # at least 100 parameters have to start with `model_ema` for the checkpoint to be considered EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding
        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        orig_index += 2

        diffusers_index = 0

        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
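

# Example (a sketch): converting the VAE weights from a Stable Diffusion v1
# checkpoint. `original_config` is assumed to be the OmegaConf object parsed
# from the model's yaml and `state_dict` the loaded checkpoint's state dict;
# both names are placeholders.
#
#   vae_config = create_vae_diffusers_config(original_config, image_size=512)
#   vae = AutoencoderKL(**vae_config)
#   vae.load_state_dict(convert_ldm_vae_checkpoint(state_dict, vae_config))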


def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0:
                # each HF layer maps to two consecutive source layers, so double the index
                i += i
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model


def convert_ldm_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


textenc_conversion_lst = [
    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

textenc_transformer_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))


def convert_paint_by_example_checkpoint(checkpoint):
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    model.model.load_state_dict(text_model_dict)

    # load mapper
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    model.mapper.load_state_dict(mapped_weights)

    # load final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # load final proj
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model


def convert_open_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

    keys = list(checkpoint.keys())

    text_model_dict = {}

    if "cond_stage_model.model.text_projection" in checkpoint:
        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
    else:
        d_model = 1024

    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in keys:
        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
            continue
        if key in textenc_conversion_map:
            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
        if key.startswith("cond_stage_model.model.transformer."):
            new_key = key[len("cond_stage_model.model.transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                new_key = new_key[: -len(".in_proj_weight")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                new_key = new_key[: -len(".in_proj_bias")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)

                text_model_dict[new_key] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """

    image_embedder_config = original_config.model.params.embedder_config

    sd_clip_image_embedder_class = image_embedder_config.target
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder


def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    noise_aug_config = original_config.model.params.noise_aug_config
    noise_aug_class = noise_aug_config.target
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }

            image_normalizer.load_state_dict(clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler


def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    ctrlnet_config.pop("sample_size")

    controlnet_model = ControlNetModel(**ctrlnet_config)

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )

    controlnet_model.load_state_dict(converted_ctrl_checkpoint)

    return controlnet_model
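

# ---------------------------------------------------------------------------
# Example usage (a sketch, not executed on import). The yaml and ckpt paths are
# placeholders; it assumes a Stable Diffusion v1 checkpoint whose weights live
# under the "state_dict" key and an OmegaConf-parsed inference config.
#
#   from omegaconf import OmegaConf
#
#   original_config = OmegaConf.load("v1-inference.yaml")
#   state_dict = torch.load("model.ckpt", map_location="cpu")["state_dict"]
#
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   unet = UNet2DConditionModel(**unet_config)
#   unet.load_state_dict(
#       convert_ldm_unet_checkpoint(state_dict, unet_config, path="model.ckpt")
#   )
#
#   text_encoder = convert_ldm_clip_checkpoint(state_dict)
# ---------------------------------------------------------------------------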