# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conversion script for Stable Diffusion checkpoints."""

import re
from io import BytesIO
from typing import Optional

import requests
import torch
from transformers import (
    AutoFeatureExtractor,
    BertTokenizerFast,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from diffusers.models import (
    AutoencoderKL,
    ControlNetModel,
    PriorTransformer,
    UNet2DConditionModel,
)

# Pipeline-level classes referenced by the converters below; the import paths assume a
# diffusers release contemporary with this script.
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UnCLIPScheduler,
)
from diffusers.utils.import_utils import BACKENDS_MAPPING


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])
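
# Illustrative behaviour of the helper above (shown as comments only; the key is an example
# LDM parameter path, not read from any checkpoint):
#   shave_segments("input_blocks.3.0.in_layers.0.weight", n_shave_prefix_segments=2)
#   -> "0.in_layers.0.weight"
#   shave_segments("input_blocks.3.0.in_layers.0.weight", n_shave_prefix_segments=-1)
#   -> "input_blocks.3.0.in_layers.0"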


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace("in_layers.0", "norm1")
        new_item = new_item.replace("in_layers.2", "conv1")

        new_item = new_item.replace("out_layers.0", "norm2")
        new_item = new_item.replace("out_layers.3", "conv2")

        new_item = new_item.replace("emb_layers.1", "time_emb_proj")
        new_item = new_item.replace("skip_connection", "conv_shortcut")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping
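
# The renamers above only perform *local* renaming: each returns a list of
# {"old": ..., "new": ...} entries that `assign_to_checkpoint` later applies.
# For example (illustrative key, not read from a checkpoint):
#   renew_vae_attention_paths(["mid.attn_1.q.weight"])
#   -> [{"old": "mid.attn_1.q.weight", "new": "mid.attn_1.query.weight"}]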


def assign_to_checkpoint(
    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
    attention layers, and takes into account additional replacements that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"], replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]
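
# Note: `conv_attn_to_linear` squeezes 1x1 convolution kernels into linear weights, e.g. an
# attention "query.weight" of shape (512, 512, 1, 1) becomes (512, 512) so it can be loaded
# into the corresponding diffusers linear layer (shapes here are illustrative).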


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a diffusers UNet (or ControlNet) config from the LDM model config.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        unet_params = original_config.model.params.unet_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config
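
# For orientation, the standard SD v1 inference config with image_size=512 yields roughly
# (values derived from the usual v1 YAML, shown for illustration only):
#   sample_size=64, in_channels=4, out_channels=4, layers_per_block=2,
#   block_out_channels=(320, 640, 1280, 1280), cross_attention_dim=768,
#   down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")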


def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a diffusers VAE config from the LDM model config.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config
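
# For orientation, the standard SD VAE config with image_size=512 comes out roughly as
# (illustrative): in_channels=3, out_channels=3, latent_channels=4, layers_per_block=2,
# block_out_channels=(128, 256, 512, 512).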


def create_diffusers_schedular(original_config):
    # NOTE: the function name is kept for compatibility with existing callers.
    scheduler = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return scheduler


def create_ldm_bert_config(original_config):
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """
    # extract state_dict for UNet
    unet_state_dict = {}
    keys = list(checkpoint.keys())

    if controlnet:
        unet_key = "control_model."
    else:
        unet_key = "model.diffusion_model."

    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
        print(
            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
        )
        for key in keys:
            if key.startswith("model.diffusion_model"):
                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
    else:
        if sum(k.startswith("model_ema") for k in keys) > 100:
            print(
                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
            )

        for key in keys:
            if key.startswith(unet_key):
                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    if config["class_embed_type"] is None:
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {
                "old": f"input_blocks.{i}.1",
                "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
            }
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding
        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        orig_index += 2

        diffusers_index = 0

        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
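
# Minimal usage sketch (file names and the OmegaConf dependency are assumptions, not part of
# this module; the original config is expected to behave like an OmegaConf object):
#
#   from omegaconf import OmegaConf
#
#   original_config = OmegaConf.load("v1-inference.yaml")
#   state_dict = torch.load("sd-v1-5.ckpt", map_location="cpu")["state_dict"]
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   converted_unet = convert_ldm_unet_checkpoint(state_dict, unet_config, path="sd-v1-5.ckpt")
#   unet = UNet2DConditionModel(**unet_config)
#   unet.load_state_dict(converted_unet)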


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
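
# Minimal usage sketch (continuing the assumed `original_config`/`state_dict` from the UNet
# example above): the VAE converter expects the *full* checkpoint, since it filters on the
# "first_stage_model." prefix itself.
#
#   vae_config = create_vae_diffusers_config(original_config, image_size=512)
#   converted_vae = convert_ldm_vae_checkpoint(state_dict, vae_config)
#   vae = AutoencoderKL(**vae_config)
#   vae.load_state_dict(converted_vae)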


def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0:
                i += i  # each HF layer maps onto two consecutive entries of the flat pt layer list
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model


def convert_ldm_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


textenc_conversion_lst = [
    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

textenc_transformer_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
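
# Illustrative renaming performed by `textenc_pattern` on an open-clip key (example key only):
#   "resblocks.0.ln_1.weight" -> "text_model.encoder.layers.0.layer_norm1.weight"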


def convert_paint_by_example_checkpoint(checkpoint):
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    model.model.load_state_dict(text_model_dict)

    # load mapper
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    model.mapper.load_state_dict(mapped_weights)

    # load final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # load final proj
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model


def convert_open_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

    keys = list(checkpoint.keys())

    text_model_dict = {}

    if "cond_stage_model.model.text_projection" in checkpoint:
        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
    else:
        d_model = 1024

    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in keys:
        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
            continue
        if key in textenc_conversion_map:
            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
        if key.startswith("cond_stage_model.model.transformer."):
            new_key = key[len("cond_stage_model.model.transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                new_key = new_key[: -len(".in_proj_weight")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                new_key = new_key[: -len(".in_proj_bias")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)

                text_model_dict[new_key] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """
    image_embedder_config = original_config.model.params.embedder_config

    sd_clip_image_embedder_class = image_embedder_config.target
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder


def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    noise_aug_config = original_config.model.params.noise_aug_config
    noise_aug_class = noise_aug_config.target
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }

            image_normalizer.load_state_dict(clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler


def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    ctrlnet_config.pop("sample_size")

    controlnet_model = ControlNetModel(**ctrlnet_config)

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )

    controlnet_model.load_state_dict(converted_ctrl_checkpoint)

    return controlnet_model
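
# Minimal usage sketch (file names are assumptions, continuing the assumed imports from the
# earlier sketches; the YAML must describe a `control_stage_config` as expected by
# `create_unet_diffusers_config(..., controlnet=True)`):
#
#   ctrl_config = OmegaConf.load("cldm_v15.yaml")
#   ctrl_state_dict = torch.load("control_sd15_canny.pth", map_location="cpu")
#   controlnet = convert_controlnet_checkpoint(
#       ctrl_state_dict, ctrl_config, "control_sd15_canny.pth",
#       image_size=512, upcast_attention=False, extract_ema=False,
#   )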