convert_from_ckpt.py

  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ Conversion script for the Stable Diffusion checkpoints."""
  16. import re
  17. from io import BytesIO
  18. from typing import Optional
  19. import requests
  20. import torch
  21. from transformers import (
  22. AutoFeatureExtractor,
  23. BertTokenizerFast,
  24. CLIPImageProcessor,
  25. CLIPTextModel,
  26. CLIPTextModelWithProjection,
  27. CLIPTokenizer,
  28. CLIPVisionConfig,
  29. CLIPVisionModelWithProjection,
  30. )
  31. from diffusers.models import (
  32. AutoencoderKL,
  33. # ControlNetModel,
  34. PriorTransformer,
  35. UNet2DConditionModel,
  36. )
  37. from diffusers.schedulers import (
  38. DDIMScheduler,
  39. DDPMScheduler,
  40. DPMSolverMultistepScheduler,
  41. EulerAncestralDiscreteScheduler,
  42. EulerDiscreteScheduler,
  43. HeunDiscreteScheduler,
  44. LMSDiscreteScheduler,
  45. PNDMScheduler,
  46. UnCLIPScheduler,
  47. )
  48. # from diffusers.utils import is_omegaconf_available, is_safetensors_available, logging
  49. from diffusers.utils.import_utils import BACKENDS_MAPPING
  50. # from diffusers.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
  51. # from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
  52. # from diffusers.pipelines.pipeline_utils import DiffusionPipeline
  53. # from .safety_checker import StableDiffusionSafetyChecker
  54. # from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
  55. # logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  56. def shave_segments(path, n_shave_prefix_segments=1):
  57. """
  58. Removes segments. Positive values shave the first segments, negative shave the last segments.
  59. """
  60. if n_shave_prefix_segments >= 0:
  61. return ".".join(path.split(".")[n_shave_prefix_segments:])
  62. else:
  63. return ".".join(path.split(".")[:n_shave_prefix_segments])
  64. def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
  65. """
  66. Updates paths inside resnets to the new naming scheme (local renaming)
  67. """
  68. mapping = []
  69. for old_item in old_list:
  70. new_item = old_item.replace("in_layers.0", "norm1")
  71. new_item = new_item.replace("in_layers.2", "conv1")
  72. new_item = new_item.replace("out_layers.0", "norm2")
  73. new_item = new_item.replace("out_layers.3", "conv2")
  74. new_item = new_item.replace("emb_layers.1", "time_emb_proj")
  75. new_item = new_item.replace("skip_connection", "conv_shortcut")
  76. new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
  77. mapping.append({"old": old_item, "new": new_item})
  78. return mapping
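# For reference, a typical resnet key is renamed locally like this (illustrative values):
#   "input_blocks.1.0.in_layers.0.weight"  -> "input_blocks.1.0.norm1.weight"
#   "input_blocks.1.0.emb_layers.1.weight" -> "input_blocks.1.0.time_emb_proj.weight"
# The block-level prefix (e.g. "input_blocks.1.0" -> "down_blocks.0.resnets.0") is applied later by
# assign_to_checkpoint through `additional_replacements`.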
  79. def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
  80. """
  81. Updates paths inside resnets to the new naming scheme (local renaming)
  82. """
  83. mapping = []
  84. for old_item in old_list:
  85. new_item = old_item
  86. new_item = new_item.replace("nin_shortcut", "conv_shortcut")
  87. new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
  88. mapping.append({"old": old_item, "new": new_item})
  89. return mapping
  90. def renew_attention_paths(old_list, n_shave_prefix_segments=0):
  91. """
  92. Updates paths inside attentions to the new naming scheme (local renaming)
  93. """
  94. mapping = []
  95. for old_item in old_list:
  96. new_item = old_item
  97. # new_item = new_item.replace('norm.weight', 'group_norm.weight')
  98. # new_item = new_item.replace('norm.bias', 'group_norm.bias')
  99. # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
  100. # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
  101. # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
  102. mapping.append({"old": old_item, "new": new_item})
  103. return mapping
  104. def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
  105. """
  106. Updates paths inside attentions to the new naming scheme (local renaming)
  107. """
  108. mapping = []
  109. for old_item in old_list:
  110. new_item = old_item
  111. new_item = new_item.replace("norm.weight", "group_norm.weight")
  112. new_item = new_item.replace("norm.bias", "group_norm.bias")
  113. new_item = new_item.replace("q.weight", "query.weight")
  114. new_item = new_item.replace("q.bias", "query.bias")
  115. new_item = new_item.replace("k.weight", "key.weight")
  116. new_item = new_item.replace("k.bias", "key.bias")
  117. new_item = new_item.replace("v.weight", "value.weight")
  118. new_item = new_item.replace("v.bias", "value.bias")
  119. new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
  120. new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
  121. new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
  122. mapping.append({"old": old_item, "new": new_item})
  123. return mapping
  124. def assign_to_checkpoint(
  125. paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
  126. ):
  127. """
  128. This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
  129. attention layers, and takes into account additional replacements that may arise.
  130. Assigns the weights to the new checkpoint.
  131. """
  132. assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
  133. # Splits the attention layers into three variables.
  134. if attention_paths_to_split is not None:
  135. for path, path_map in attention_paths_to_split.items():
  136. old_tensor = old_checkpoint[path]
  137. channels = old_tensor.shape[0] // 3
  138. target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
  139. num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
  140. old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
  141. query, key, value = old_tensor.split(channels // num_heads, dim=1)
  142. checkpoint[path_map["query"]] = query.reshape(target_shape)
  143. checkpoint[path_map["key"]] = key.reshape(target_shape)
  144. checkpoint[path_map["value"]] = value.reshape(target_shape)
  145. for path in paths:
  146. new_path = path["new"]
  147. # These have already been assigned
  148. if attention_paths_to_split is not None and new_path in attention_paths_to_split:
  149. continue
  150. # Global renaming happens here
  151. new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
  152. new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
  153. new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
  154. if additional_replacements is not None:
  155. for replacement in additional_replacements:
  156. new_path = new_path.replace(replacement["old"], replacement["new"])
  157. # proj_attn.weight has to be converted from conv 1D to linear
  158. if "proj_attn.weight" in new_path:
  159. checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
  160. else:
  161. checkpoint[new_path] = old_checkpoint[path["old"]]
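# Minimal usage sketch (the key names below are illustrative, not taken from a real checkpoint):
#   paths = renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
#   assign_to_checkpoint(
#       paths,
#       new_checkpoint,
#       unet_state_dict,
#       additional_replacements=[{"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}],
#       config=config,
#   )
# copies unet_state_dict["input_blocks.1.0.in_layers.0.weight"]
# into new_checkpoint["down_blocks.0.resnets.0.norm1.weight"].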
  162. def conv_attn_to_linear(checkpoint):
  163. keys = list(checkpoint.keys())
  164. attn_keys = ["query.weight", "key.weight", "value.weight"]
  165. for key in keys:
  166. if ".".join(key.split(".")[-2:]) in attn_keys:
  167. if checkpoint[key].ndim > 2:
  168. checkpoint[key] = checkpoint[key][:, :, 0, 0]
  169. elif "proj_attn.weight" in key:
  170. if checkpoint[key].ndim > 2:
  171. checkpoint[key] = checkpoint[key][:, :, 0]
  172. def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
  173. """
  174. Creates a config for the diffusers based on the config of the LDM model.
  175. """
  176. if controlnet:
  177. unet_params = original_config.model.params.control_stage_config.params
  178. else:
  179. unet_params = original_config.model.params.unet_config.params
  180. vae_params = original_config.model.params.first_stage_config.params.ddconfig
  181. block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
  182. down_block_types = []
  183. resolution = 1
  184. for i in range(len(block_out_channels)):
  185. block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
  186. down_block_types.append(block_type)
  187. if i != len(block_out_channels) - 1:
  188. resolution *= 2
  189. up_block_types = []
  190. for i in range(len(block_out_channels)):
  191. block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
  192. up_block_types.append(block_type)
  193. resolution //= 2
  194. vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
  195. head_dim = unet_params.num_heads if "num_heads" in unet_params else None
  196. use_linear_projection = (
  197. unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
  198. )
  199. if use_linear_projection:
  200. # stable diffusion 2-base-512 and 2-768
  201. if head_dim is None:
  202. head_dim = [5, 10, 20, 20]
  203. class_embed_type = None
  204. projection_class_embeddings_input_dim = None
  205. if "num_classes" in unet_params:
  206. if unet_params.num_classes == "sequential":
  207. class_embed_type = "projection"
  208. assert "adm_in_channels" in unet_params
  209. projection_class_embeddings_input_dim = unet_params.adm_in_channels
  210. else:
  211. raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
  212. config = {
  213. "sample_size": image_size // vae_scale_factor,
  214. "in_channels": unet_params.in_channels,
  215. "down_block_types": tuple(down_block_types),
  216. "block_out_channels": tuple(block_out_channels),
  217. "layers_per_block": unet_params.num_res_blocks,
  218. "cross_attention_dim": unet_params.context_dim,
  219. "attention_head_dim": head_dim,
  220. "use_linear_projection": use_linear_projection,
  221. "class_embed_type": class_embed_type,
  222. "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
  223. }
  224. if not controlnet:
  225. config["out_channels"] = unet_params.out_channels
  226. config["up_block_types"] = tuple(up_block_types)
  227. return config
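# Usage sketch (the yaml path is a placeholder): build the diffusers UNet config from an original
# LDM config loaded with OmegaConf, then instantiate the model.
#   from omegaconf import OmegaConf
#   original_config = OmegaConf.load("v1-inference.yaml")
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   unet = UNet2DConditionModel(**unet_config)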
  228. def create_vae_diffusers_config(original_config, image_size: int):
  229. """
  230. Creates a config for the diffusers based on the config of the LDM model.
  231. """
  232. vae_params = original_config.model.params.first_stage_config.params.ddconfig
  233. _ = original_config.model.params.first_stage_config.params.embed_dim
  234. block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
  235. down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
  236. up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
  237. config = {
  238. "sample_size": image_size,
  239. "in_channels": vae_params.in_channels,
  240. "out_channels": vae_params.out_ch,
  241. "down_block_types": tuple(down_block_types),
  242. "up_block_types": tuple(up_block_types),
  243. "block_out_channels": tuple(block_out_channels),
  244. "latent_channels": vae_params.z_channels,
  245. "layers_per_block": vae_params.num_res_blocks,
  246. }
  247. return config
  248. def create_diffusers_schedular(original_config):
  249. schedular = DDIMScheduler(
  250. num_train_timesteps=original_config.model.params.timesteps,
  251. beta_start=original_config.model.params.linear_start,
  252. beta_end=original_config.model.params.linear_end,
  253. beta_schedule="scaled_linear",
  254. )
  255. return schedular
  256. def create_ldm_bert_config(original_config):
257. bert_params = original_config.model.params.cond_stage_config.params
  258. config = LDMBertConfig(
  259. d_model=bert_params.n_embed,
  260. encoder_layers=bert_params.n_layer,
  261. encoder_ffn_dim=bert_params.n_embed * 4,
  262. )
  263. return config
  264. def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
  265. """
  266. Takes a state dict and a config, and returns a converted checkpoint.
  267. """
  268. # extract state_dict for UNet
  269. unet_state_dict = {}
  270. keys = list(checkpoint.keys())
  271. if controlnet:
  272. unet_key = "control_model."
  273. else:
  274. unet_key = "model.diffusion_model."
275. # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
  276. if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
  277. print(f"Checkpoint {path} has both EMA and non-EMA weights.")
  278. print(
  279. "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
  280. " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
  281. )
  282. for key in keys:
  283. if key.startswith("model.diffusion_model"):
  284. flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
  285. unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
  286. else:
  287. if sum(k.startswith("model_ema") for k in keys) > 100:
  288. print(
  289. "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
  290. " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
  291. )
  292. for key in keys:
  293. if key.startswith(unet_key):
  294. unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
  295. new_checkpoint = {}
  296. new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
  297. new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
  298. new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
  299. new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
  300. if config["class_embed_type"] is None:
  301. # No parameters to port
  302. ...
  303. elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
  304. new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
  305. new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
  306. new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
  307. new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
  308. else:
  309. raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
  310. new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
  311. new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
  312. if not controlnet:
  313. new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
  314. new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
  315. new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
  316. new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
  317. # Retrieves the keys for the input blocks only
  318. num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
  319. input_blocks = {
  320. layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
  321. for layer_id in range(num_input_blocks)
  322. }
  323. # Retrieves the keys for the middle blocks only
  324. num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
  325. middle_blocks = {
  326. layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
  327. for layer_id in range(num_middle_blocks)
  328. }
  329. # Retrieves the keys for the output blocks only
  330. num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
  331. output_blocks = {
  332. layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
  333. for layer_id in range(num_output_blocks)
  334. }
  335. for i in range(1, num_input_blocks):
  336. block_id = (i - 1) // (config["layers_per_block"] + 1)
  337. layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
  338. resnets = [
  339. key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
  340. ]
  341. attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
  342. if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
  343. new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
  344. f"input_blocks.{i}.0.op.weight"
  345. )
  346. new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
  347. f"input_blocks.{i}.0.op.bias"
  348. )
  349. paths = renew_resnet_paths(resnets)
  350. meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
  351. assign_to_checkpoint(
  352. paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
  353. )
  354. if len(attentions):
  355. paths = renew_attention_paths(attentions)
  356. meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
  357. assign_to_checkpoint(
  358. paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
  359. )
  360. resnet_0 = middle_blocks[0]
  361. attentions = middle_blocks[1]
  362. resnet_1 = middle_blocks[2]
  363. resnet_0_paths = renew_resnet_paths(resnet_0)
  364. assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
  365. resnet_1_paths = renew_resnet_paths(resnet_1)
  366. assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
  367. attentions_paths = renew_attention_paths(attentions)
  368. meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
  369. assign_to_checkpoint(
  370. attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
  371. )
  372. for i in range(num_output_blocks):
  373. block_id = i // (config["layers_per_block"] + 1)
  374. layer_in_block_id = i % (config["layers_per_block"] + 1)
  375. output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
  376. output_block_list = {}
  377. for layer in output_block_layers:
  378. layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
  379. if layer_id in output_block_list:
  380. output_block_list[layer_id].append(layer_name)
  381. else:
  382. output_block_list[layer_id] = [layer_name]
  383. if len(output_block_list) > 1:
  384. resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
  385. attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
  386. resnet_0_paths = renew_resnet_paths(resnets)
  387. paths = renew_resnet_paths(resnets)
  388. meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
  389. assign_to_checkpoint(
  390. paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
  391. )
  392. output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
  393. if ["conv.bias", "conv.weight"] in output_block_list.values():
  394. index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
  395. new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
  396. f"output_blocks.{i}.{index}.conv.weight"
  397. ]
  398. new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
  399. f"output_blocks.{i}.{index}.conv.bias"
  400. ]
  401. # Clear attentions as they have been attributed above.
  402. if len(attentions) == 2:
  403. attentions = []
  404. if len(attentions):
  405. paths = renew_attention_paths(attentions)
  406. meta_path = {
  407. "old": f"output_blocks.{i}.1",
  408. "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
  409. }
  410. assign_to_checkpoint(
  411. paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
  412. )
  413. else:
  414. resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
  415. for path in resnet_0_paths:
  416. old_path = ".".join(["output_blocks", str(i), path["old"]])
  417. new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
  418. new_checkpoint[new_path] = unet_state_dict[old_path]
  419. if controlnet:
  420. # conditioning embedding
  421. orig_index = 0
  422. new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
  423. f"input_hint_block.{orig_index}.weight"
  424. )
  425. new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
  426. f"input_hint_block.{orig_index}.bias"
  427. )
  428. orig_index += 2
  429. diffusers_index = 0
  430. while diffusers_index < 6:
  431. new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
  432. f"input_hint_block.{orig_index}.weight"
  433. )
  434. new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
  435. f"input_hint_block.{orig_index}.bias"
  436. )
  437. diffusers_index += 1
  438. orig_index += 2
  439. new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
  440. f"input_hint_block.{orig_index}.weight"
  441. )
  442. new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
  443. f"input_hint_block.{orig_index}.bias"
  444. )
  445. # down blocks
  446. for i in range(num_input_blocks):
  447. new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
  448. new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
  449. # mid block
  450. new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
  451. new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
  452. return new_checkpoint
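# Usage sketch (checkpoint path is a placeholder): convert the UNet weights from an original
# ".ckpt" file and load them into the diffusers model built from the config above.
#   checkpoint = torch.load("sd-v1-4.ckpt", map_location="cpu")["state_dict"]
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   converted_unet = convert_ldm_unet_checkpoint(checkpoint, unet_config, path="sd-v1-4.ckpt")
#   unet = UNet2DConditionModel(**unet_config)
#   unet.load_state_dict(converted_unet)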
  453. def convert_ldm_vae_checkpoint(checkpoint, config):
  454. # extract state dict for VAE
  455. vae_state_dict = {}
  456. vae_key = "first_stage_model."
  457. keys = list(checkpoint.keys())
  458. for key in keys:
  459. if key.startswith(vae_key):
  460. vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
  461. new_checkpoint = {}
  462. new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
  463. new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
  464. new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
  465. new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
  466. new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
  467. new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
  468. new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
  469. new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
  470. new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
  471. new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
  472. new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
  473. new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
  474. new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
  475. new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
  476. new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
  477. new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
  478. # Retrieves the keys for the encoder down blocks only
  479. num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
  480. down_blocks = {
  481. layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
  482. }
  483. # Retrieves the keys for the decoder up blocks only
  484. num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
  485. up_blocks = {
  486. layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
  487. }
  488. for i in range(num_down_blocks):
  489. resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
  490. if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
  491. new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
  492. f"encoder.down.{i}.downsample.conv.weight"
  493. )
  494. new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
  495. f"encoder.down.{i}.downsample.conv.bias"
  496. )
  497. paths = renew_vae_resnet_paths(resnets)
  498. meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
  499. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  500. mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
  501. num_mid_res_blocks = 2
  502. for i in range(1, num_mid_res_blocks + 1):
  503. resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
  504. paths = renew_vae_resnet_paths(resnets)
  505. meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
  506. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  507. mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
  508. paths = renew_vae_attention_paths(mid_attentions)
  509. meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
  510. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  511. conv_attn_to_linear(new_checkpoint)
  512. for i in range(num_up_blocks):
  513. block_id = num_up_blocks - 1 - i
  514. resnets = [
  515. key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
  516. ]
  517. if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
  518. new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
  519. f"decoder.up.{block_id}.upsample.conv.weight"
  520. ]
  521. new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
  522. f"decoder.up.{block_id}.upsample.conv.bias"
  523. ]
  524. paths = renew_vae_resnet_paths(resnets)
  525. meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
  526. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  527. mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
  528. num_mid_res_blocks = 2
  529. for i in range(1, num_mid_res_blocks + 1):
  530. resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
  531. paths = renew_vae_resnet_paths(resnets)
  532. meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
  533. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  534. mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
  535. paths = renew_vae_attention_paths(mid_attentions)
  536. meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
  537. assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
  538. conv_attn_to_linear(new_checkpoint)
  539. return new_checkpoint
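# Usage sketch (mirrors the UNet conversion above; `checkpoint` and `original_config` as before):
#   vae_config = create_vae_diffusers_config(original_config, image_size=512)
#   converted_vae = convert_ldm_vae_checkpoint(checkpoint, vae_config)
#   vae = AutoencoderKL(**vae_config)
#   vae.load_state_dict(converted_vae)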
  540. def convert_ldm_bert_checkpoint(checkpoint, config):
  541. def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
  542. hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
  543. hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
  544. hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
  545. hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
  546. hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
  547. def _copy_linear(hf_linear, pt_linear):
  548. hf_linear.weight = pt_linear.weight
  549. hf_linear.bias = pt_linear.bias
  550. def _copy_layer(hf_layer, pt_layer):
  551. # copy layer norms
  552. _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
  553. _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
  554. # copy attn
  555. _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
  556. # copy MLP
  557. pt_mlp = pt_layer[1][1]
  558. _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
  559. _copy_linear(hf_layer.fc2, pt_mlp.net[2])
  560. def _copy_layers(hf_layers, pt_layers):
  561. for i, hf_layer in enumerate(hf_layers):
  562. if i != 0:
563. i += i  # jump ahead: each HF layer corresponds to two consecutive entries (attention + feed-forward) in pt_layers
  564. pt_layer = pt_layers[i : i + 2]
  565. _copy_layer(hf_layer, pt_layer)
  566. hf_model = LDMBertModel(config).eval()
  567. # copy embeds
  568. hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
  569. hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
  570. # copy layer norm
  571. _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
  572. # copy hidden layers
  573. _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
  574. _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
  575. return hf_model
  576. def convert_ldm_clip_checkpoint(checkpoint):
  577. # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
  578. text_model = CLIPTextModel.from_pretrained("/mnt/petrelfs/guoyuwei/projects/huggingface/clip-vit-large-patch14")
  579. keys = list(checkpoint.keys())
  580. text_model_dict = {}
  581. for key in keys:
  582. if key.startswith("cond_stage_model.transformer"):
  583. text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
  584. text_model.load_state_dict(text_model_dict)
  585. return text_model
  586. textenc_conversion_lst = [
  587. ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
  588. ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
  589. ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
  590. ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
  591. ]
  592. textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
  593. textenc_transformer_conversion_lst = [
  594. # (stable-diffusion, HF Diffusers)
  595. ("resblocks.", "text_model.encoder.layers."),
  596. ("ln_1", "layer_norm1"),
  597. ("ln_2", "layer_norm2"),
  598. (".c_fc.", ".fc1."),
  599. (".c_proj.", ".fc2."),
  600. (".attn", ".self_attn"),
  601. ("ln_final.", "transformer.text_model.final_layer_norm."),
  602. ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
  603. ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
  604. ]
  605. protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
  606. textenc_pattern = re.compile("|".join(protected.keys()))
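# For reference, textenc_pattern rewrites open-CLIP transformer keys into HF CLIP keys, e.g.
#   "resblocks.0.ln_1.weight"     -> "text_model.encoder.layers.0.layer_norm1.weight"
#   "resblocks.0.mlp.c_fc.weight" -> "text_model.encoder.layers.0.mlp.fc1.weight"
# (the fused ".attn.in_proj_*" tensors are handled separately in convert_open_clip_checkpoint,
# where they are split into q/k/v projections).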
  607. def convert_paint_by_example_checkpoint(checkpoint):
  608. config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
  609. model = PaintByExampleImageEncoder(config)
  610. keys = list(checkpoint.keys())
  611. text_model_dict = {}
  612. for key in keys:
  613. if key.startswith("cond_stage_model.transformer"):
  614. text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
  615. # load clip vision
  616. model.model.load_state_dict(text_model_dict)
  617. # load mapper
  618. keys_mapper = {
  619. k[len("cond_stage_model.mapper.res") :]: v
  620. for k, v in checkpoint.items()
  621. if k.startswith("cond_stage_model.mapper")
  622. }
  623. MAPPING = {
  624. "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
  625. "attn.c_proj": ["attn1.to_out.0"],
  626. "ln_1": ["norm1"],
  627. "ln_2": ["norm3"],
  628. "mlp.c_fc": ["ff.net.0.proj"],
  629. "mlp.c_proj": ["ff.net.2"],
  630. }
  631. mapped_weights = {}
  632. for key, value in keys_mapper.items():
  633. prefix = key[: len("blocks.i")]
  634. suffix = key.split(prefix)[-1].split(".")[-1]
  635. name = key.split(prefix)[-1].split(suffix)[0][1:-1]
  636. mapped_names = MAPPING[name]
  637. num_splits = len(mapped_names)
  638. for i, mapped_name in enumerate(mapped_names):
  639. new_name = ".".join([prefix, mapped_name, suffix])
  640. shape = value.shape[0] // num_splits
  641. mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
  642. model.mapper.load_state_dict(mapped_weights)
  643. # load final layer norm
  644. model.final_layer_norm.load_state_dict(
  645. {
  646. "bias": checkpoint["cond_stage_model.final_ln.bias"],
  647. "weight": checkpoint["cond_stage_model.final_ln.weight"],
  648. }
  649. )
  650. # load final proj
  651. model.proj_out.load_state_dict(
  652. {
  653. "bias": checkpoint["proj_out.bias"],
  654. "weight": checkpoint["proj_out.weight"],
  655. }
  656. )
  657. # load uncond vector
  658. model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
  659. return model
  660. def convert_open_clip_checkpoint(checkpoint):
  661. text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
  662. keys = list(checkpoint.keys())
  663. text_model_dict = {}
  664. if "cond_stage_model.model.text_projection" in checkpoint:
  665. d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
  666. else:
  667. d_model = 1024
  668. text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
  669. for key in keys:
  670. if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
  671. continue
  672. if key in textenc_conversion_map:
  673. text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
  674. if key.startswith("cond_stage_model.model.transformer."):
  675. new_key = key[len("cond_stage_model.model.transformer.") :]
  676. if new_key.endswith(".in_proj_weight"):
  677. new_key = new_key[: -len(".in_proj_weight")]
  678. new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
  679. text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
  680. text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
  681. text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
  682. elif new_key.endswith(".in_proj_bias"):
  683. new_key = new_key[: -len(".in_proj_bias")]
  684. new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
  685. text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
  686. text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
  687. text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
  688. else:
  689. new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
  690. text_model_dict[new_key] = checkpoint[key]
  691. text_model.load_state_dict(text_model_dict)
  692. return text_model
  693. def stable_unclip_image_encoder(original_config):
  694. """
  695. Returns the image processor and clip image encoder for the img2img unclip pipeline.
  696. We currently know of two types of stable unclip models which separately use the clip and the openclip image
  697. encoders.
  698. """
  699. image_embedder_config = original_config.model.params.embedder_config
  700. sd_clip_image_embedder_class = image_embedder_config.target
  701. sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
  702. if sd_clip_image_embedder_class == "ClipImageEmbedder":
  703. clip_model_name = image_embedder_config.params.model
  704. if clip_model_name == "ViT-L/14":
  705. feature_extractor = CLIPImageProcessor()
  706. image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
  707. else:
  708. raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
  709. elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
  710. feature_extractor = CLIPImageProcessor()
  711. image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
  712. else:
  713. raise NotImplementedError(
  714. f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
  715. )
  716. return feature_extractor, image_encoder
  717. def stable_unclip_image_noising_components(
  718. original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
  719. ):
  720. """
  721. Returns the noising components for the img2img and txt2img unclip pipelines.
  722. Converts the stability noise augmentor into
  723. 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
  724. 2. a `DDPMScheduler` for holding the noise schedule
  725. If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
  726. """
  727. noise_aug_config = original_config.model.params.noise_aug_config
  728. noise_aug_class = noise_aug_config.target
  729. noise_aug_class = noise_aug_class.split(".")[-1]
  730. if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
  731. noise_aug_config = noise_aug_config.params
  732. embedding_dim = noise_aug_config.timestep_dim
  733. max_noise_level = noise_aug_config.noise_schedule_config.timesteps
  734. beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
  735. image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
  736. image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
  737. if "clip_stats_path" in noise_aug_config:
  738. if clip_stats_path is None:
  739. raise ValueError("This stable unclip config requires a `clip_stats_path`")
  740. clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
  741. clip_mean = clip_mean[None, :]
  742. clip_std = clip_std[None, :]
  743. clip_stats_state_dict = {
  744. "mean": clip_mean,
  745. "std": clip_std,
  746. }
  747. image_normalizer.load_state_dict(clip_stats_state_dict)
  748. else:
  749. raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
  750. return image_normalizer, image_noising_scheduler
  751. def convert_controlnet_checkpoint(
  752. checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
  753. ):
  754. ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
  755. ctrlnet_config["upcast_attention"] = upcast_attention
  756. ctrlnet_config.pop("sample_size")
  757. controlnet_model = ControlNetModel(**ctrlnet_config)
  758. converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
  759. checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
  760. )
  761. controlnet_model.load_state_dict(converted_ctrl_checkpoint)
  762. return controlnet_model
  763. # def download_from_original_stable_diffusion_ckpt(
  764. # checkpoint_path: str,
  765. # original_config_file: str = None,
  766. # image_size: int = 512,
  767. # prediction_type: str = None,
  768. # model_type: str = None,
  769. # extract_ema: bool = False,
  770. # scheduler_type: str = "pndm",
  771. # num_in_channels: Optional[int] = None,
  772. # upcast_attention: Optional[bool] = None,
  773. # device: str = None,
  774. # from_safetensors: bool = False,
  775. # stable_unclip: Optional[str] = None,
  776. # stable_unclip_prior: Optional[str] = None,
  777. # clip_stats_path: Optional[str] = None,
  778. # controlnet: Optional[bool] = None,
  779. # load_safety_checker: bool = True,
  780. # pipeline_class: DiffusionPipeline = None,
  781. # ) -> DiffusionPipeline:
  782. # """
  783. # Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
  784. # config file.
  785. # Although many of the arguments can be automatically inferred, some of these rely on brittle checks against the
  786. # global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
  787. # recommended that you override the default values and/or supply an `original_config_file` wherever possible.
  788. # Args:
  789. # checkpoint_path (`str`): Path to `.ckpt` file.
  790. # original_config_file (`str`):
  791. # Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
  792. # inferred by looking for a key that only exists in SD2.0 models.
  793. # image_size (`int`, *optional*, defaults to 512):
  794. # The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
  795. # Base. Use 768 for Stable Diffusion v2.
  796. # prediction_type (`str`, *optional*):
  797. # The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
  798. # Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
  799. # num_in_channels (`int`, *optional*, defaults to None):
  800. # The number of input channels. If `None`, it will be automatically inferred.
  801. # scheduler_type (`str`, *optional*, defaults to 'pndm'):
  802. # Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
  803. # "ddim"]`.
  804. # model_type (`str`, *optional*, defaults to `None`):
  805. # The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
  806. # "FrozenCLIPEmbedder", "PaintByExample"]`.
  807. # is_img2img (`bool`, *optional*, defaults to `False`):
  808. # Whether the model should be loaded as an img2img pipeline.
  809. # extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
  810. # checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
  811. # `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
  812. # inference. Non-EMA weights are usually better to continue fine-tuning.
  813. # upcast_attention (`bool`, *optional*, defaults to `None`):
  814. # Whether the attention computation should always be upcasted. This is necessary when running stable
  815. # diffusion 2.1.
  816. # device (`str`, *optional*, defaults to `None`):
  817. # The device to use. Pass `None` to determine automatically.
  818. # from_safetensors (`str`, *optional*, defaults to `False`):
  819. # If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
  820. # load_safety_checker (`bool`, *optional*, defaults to `True`):
  821. # Whether to load the safety checker or not. Defaults to `True`.
  822. # pipeline_class (`str`, *optional*, defaults to `None`):
  823. # The pipeline class to use. Pass `None` to determine automatically.
  824. # return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
  825. # """
  826. # # import pipelines here to avoid circular import error when using from_ckpt method
  827. # from diffusers import (
  828. # LDMTextToImagePipeline,
  829. # PaintByExamplePipeline,
  830. # StableDiffusionControlNetPipeline,
  831. # StableDiffusionPipeline,
  832. # StableUnCLIPImg2ImgPipeline,
  833. # StableUnCLIPPipeline,
  834. # )
  835. # if pipeline_class is None:
  836. # pipeline_class = StableDiffusionPipeline
  837. # if prediction_type == "v-prediction":
  838. # prediction_type = "v_prediction"
  839. # if not is_omegaconf_available():
  840. # raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
  841. # from omegaconf import OmegaConf
  842. # if from_safetensors:
  843. # if not is_safetensors_available():
  844. # raise ValueError(BACKENDS_MAPPING["safetensors"][1])
  845. # from safetensors import safe_open
  846. # checkpoint = {}
  847. # with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
  848. # for key in f.keys():
  849. # checkpoint[key] = f.get_tensor(key)
  850. # else:
  851. # if device is None:
  852. # device = "cuda" if torch.cuda.is_available() else "cpu"
  853. # checkpoint = torch.load(checkpoint_path, map_location=device)
  854. # else:
  855. # checkpoint = torch.load(checkpoint_path, map_location=device)
  856. # # Sometimes models don't have the global_step item
  857. # if "global_step" in checkpoint:
  858. # global_step = checkpoint["global_step"]
  859. # else:
  860. # print("global_step key not found in model")
  861. # global_step = None
  862. # # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
  863. # # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
  864. # while "state_dict" in checkpoint:
  865. # checkpoint = checkpoint["state_dict"]
  866. # if original_config_file is None:
  867. # key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
  868. # # model_type = "v1"
  869. # config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
  870. # if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
  871. # # model_type = "v2"
  872. # config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
  873. # if global_step == 110000:
  874. # # v2.1 needs to upcast attention
  875. # upcast_attention = True
  876. # original_config_file = BytesIO(requests.get(config_url).content)
  877. # original_config = OmegaConf.load(original_config_file)
  878. # if num_in_channels is not None:
  879. # original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
  880. # if (
  881. # "parameterization" in original_config["model"]["params"]
  882. # and original_config["model"]["params"]["parameterization"] == "v"
  883. # ):
  884. # if prediction_type is None:
  885. # # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
  886. # # as it relies on a brittle global step parameter here
  887. # prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
  888. # if image_size is None:
  889. # # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
  890. # # as it relies on a brittle global step parameter here
  891. # image_size = 512 if global_step == 875000 else 768
  892. # else:
  893. # if prediction_type is None:
  894. # prediction_type = "epsilon"
  895. # if image_size is None:
  896. # image_size = 512
  897. # if controlnet is None:
  898. # controlnet = "control_stage_config" in original_config.model.params
  899. # if controlnet:
  900. # controlnet_model = convert_controlnet_checkpoint(
  901. # checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
  902. # )
  903. # num_train_timesteps = original_config.model.params.timesteps
  904. # beta_start = original_config.model.params.linear_start
  905. # beta_end = original_config.model.params.linear_end
  906. # scheduler = DDIMScheduler(
  907. # beta_end=beta_end,
  908. # beta_schedule="scaled_linear",
  909. # beta_start=beta_start,
  910. # num_train_timesteps=num_train_timesteps,
  911. # steps_offset=1,
  912. # clip_sample=False,
  913. # set_alpha_to_one=False,
  914. # prediction_type=prediction_type,
  915. # )
  916. # # make sure scheduler works correctly with DDIM
  917. # scheduler.register_to_config(clip_sample=False)
  918. # if scheduler_type == "pndm":
  919. # config = dict(scheduler.config)
  920. # config["skip_prk_steps"] = True
  921. # scheduler = PNDMScheduler.from_config(config)
  922. # elif scheduler_type == "lms":
  923. # scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
  924. # elif scheduler_type == "heun":
  925. # scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
  926. # elif scheduler_type == "euler":
  927. # scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
  928. # elif scheduler_type == "euler-ancestral":
  929. # scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
  930. # elif scheduler_type == "dpm":
  931. # scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
  932. # elif scheduler_type == "ddim":
  933. # scheduler = scheduler
  934. # else:
  935. # raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
  936. # # Convert the UNet2DConditionModel model.
  937. # unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
  938. # unet_config["upcast_attention"] = upcast_attention
  939. # unet = UNet2DConditionModel(**unet_config)
  940. # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
  941. # checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
  942. # )
  943. # unet.load_state_dict(converted_unet_checkpoint)
  944. # # Convert the VAE model.
  945. # vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
  946. # converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
  947. # vae = AutoencoderKL(**vae_config)
  948. # vae.load_state_dict(converted_vae_checkpoint)
  949. # # Convert the text model.
  950. # if model_type is None:
  951. # model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
  952. # logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
  953. # if model_type == "FrozenOpenCLIPEmbedder":
  954. # text_model = convert_open_clip_checkpoint(checkpoint)
  955. # tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
  956. # if stable_unclip is None:
  957. # if controlnet:
  958. # pipe = StableDiffusionControlNetPipeline(
  959. # vae=vae,
  960. # text_encoder=text_model,
  961. # tokenizer=tokenizer,
  962. # unet=unet,
  963. # scheduler=scheduler,
  964. # controlnet=controlnet_model,
  965. # safety_checker=None,
  966. # feature_extractor=None,
  967. # requires_safety_checker=False,
  968. # )
  969. # else:
  970. # pipe = pipeline_class(
  971. # vae=vae,
  972. # text_encoder=text_model,
  973. # tokenizer=tokenizer,
  974. # unet=unet,
  975. # scheduler=scheduler,
  976. # safety_checker=None,
  977. # feature_extractor=None,
  978. # requires_safety_checker=False,
  979. # )
  980. # else:
  981. # image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
  982. # original_config, clip_stats_path=clip_stats_path, device=device
  983. # )
  984. # if stable_unclip == "img2img":
  985. # feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
  986. # pipe = StableUnCLIPImg2ImgPipeline(
  987. # # image encoding components
  988. # feature_extractor=feature_extractor,
  989. # image_encoder=image_encoder,
  990. # # image noising components
  991. # image_normalizer=image_normalizer,
  992. # image_noising_scheduler=image_noising_scheduler,
  993. # # regular denoising components
  994. # tokenizer=tokenizer,
  995. # text_encoder=text_model,
  996. # unet=unet,
  997. # scheduler=scheduler,
  998. # # vae
  999. # vae=vae,
  1000. # )
  1001. # elif stable_unclip == "txt2img":
  1002. # if stable_unclip_prior is None or stable_unclip_prior == "karlo":
  1003. # karlo_model = "kakaobrain/karlo-v1-alpha"
  1004. # prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
  1005. # prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
  1006. # prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
  1007. # prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
  1008. # prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
  1009. # else:
  1010. # raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
  1011. # pipe = StableUnCLIPPipeline(
  1012. # # prior components
  1013. # prior_tokenizer=prior_tokenizer,
  1014. # prior_text_encoder=prior_text_model,
  1015. # prior=prior,
  1016. # prior_scheduler=prior_scheduler,
  1017. # # image noising components
  1018. # image_normalizer=image_normalizer,
  1019. # image_noising_scheduler=image_noising_scheduler,
  1020. # # regular denoising components
  1021. # tokenizer=tokenizer,
  1022. # text_encoder=text_model,
  1023. # unet=unet,
  1024. # scheduler=scheduler,
  1025. # # vae
  1026. # vae=vae,
  1027. # )
  1028. # else:
  1029. # raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
  1030. # elif model_type == "PaintByExample":
  1031. # vision_model = convert_paint_by_example_checkpoint(checkpoint)
  1032. # tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
  1033. # feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
  1034. # pipe = PaintByExamplePipeline(
  1035. # vae=vae,
  1036. # image_encoder=vision_model,
  1037. # unet=unet,
  1038. # scheduler=scheduler,
  1039. # safety_checker=None,
  1040. # feature_extractor=feature_extractor,
  1041. # )
  1042. # elif model_type == "FrozenCLIPEmbedder":
  1043. # text_model = convert_ldm_clip_checkpoint(checkpoint)
  1044. # # tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
  1045. # tokenizer = CLIPTokenizer.from_pretrained("/mnt/petrelfs/guoyuwei/projects/huggingface/clip-vit-large-patch14")
  1046. # # if load_safety_checker:
  1047. # if False:
  1048. # safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
  1049. # feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
  1050. # else:
  1051. # safety_checker = None
  1052. # feature_extractor = None
  1053. # if controlnet:
  1054. # pipe = StableDiffusionControlNetPipeline(
  1055. # vae=vae,
  1056. # text_encoder=text_model,
  1057. # tokenizer=tokenizer,
  1058. # unet=unet,
  1059. # controlnet=controlnet_model,
  1060. # scheduler=scheduler,
  1061. # safety_checker=safety_checker,
  1062. # feature_extractor=feature_extractor,
  1063. # )
  1064. # else:
  1065. # pipe = pipeline_class(
  1066. # vae=vae,
  1067. # text_encoder=text_model,
  1068. # tokenizer=tokenizer,
  1069. # unet=unet,
  1070. # scheduler=scheduler,
  1071. # safety_checker=safety_checker,
  1072. # feature_extractor=feature_extractor,
  1073. # )
  1074. # else:
  1075. # text_config = create_ldm_bert_config(original_config)
  1076. # text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
  1077. # tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
  1078. # pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
  1079. # return pipe
  1080. # def download_controlnet_from_original_ckpt(
  1081. # checkpoint_path: str,
  1082. # original_config_file: str,
  1083. # image_size: int = 512,
  1084. # extract_ema: bool = False,
  1085. # num_in_channels: Optional[int] = None,
  1086. # upcast_attention: Optional[bool] = None,
  1087. # device: str = None,
  1088. # from_safetensors: bool = False,
  1089. # ) -> DiffusionPipeline:
  1090. # if not is_omegaconf_available():
  1091. # raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
  1092. # from omegaconf import OmegaConf
  1093. # if from_safetensors:
  1094. # if not is_safetensors_available():
  1095. # raise ValueError(BACKENDS_MAPPING["safetensors"][1])
  1096. # from safetensors import safe_open
  1097. # checkpoint = {}
  1098. # with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
  1099. # for key in f.keys():
  1100. # checkpoint[key] = f.get_tensor(key)
  1101. # else:
  1102. # if device is None:
  1103. # device = "cuda" if torch.cuda.is_available() else "cpu"
  1104. # checkpoint = torch.load(checkpoint_path, map_location=device)
  1105. # else:
  1106. # checkpoint = torch.load(checkpoint_path, map_location=device)
  1107. # # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
  1108. # # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
  1109. # while "state_dict" in checkpoint:
  1110. # checkpoint = checkpoint["state_dict"]
  1111. # original_config = OmegaConf.load(original_config_file)
  1112. # if num_in_channels is not None:
  1113. # original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
  1114. # if "control_stage_config" not in original_config.model.params:
  1115. # raise ValueError("`control_stage_config` not present in original config")
  1116. # controlnet_model = convert_controlnet_checkpoint(
  1117. # checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
  1118. # )
  1119. # return controlnet_model
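# End-to-end sketch for a v1-style checkpoint (paths are placeholders; the full pipeline assembly in
# this file is commented out, so this only illustrates how the active helpers fit together):
#   from omegaconf import OmegaConf
#   from diffusers import StableDiffusionPipeline
#
#   checkpoint = torch.load("sd-v1-4.ckpt", map_location="cpu")
#   while "state_dict" in checkpoint:
#       checkpoint = checkpoint["state_dict"]
#   original_config = OmegaConf.load("v1-inference.yaml")
#
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   unet = UNet2DConditionModel(**unet_config)
#   unet.load_state_dict(convert_ldm_unet_checkpoint(checkpoint, unet_config))
#
#   vae_config = create_vae_diffusers_config(original_config, image_size=512)
#   vae = AutoencoderKL(**vae_config)
#   vae.load_state_dict(convert_ldm_vae_checkpoint(checkpoint, vae_config))
#
#   text_encoder = convert_ldm_clip_checkpoint(checkpoint)  # note the hard-coded local CLIP path above
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   scheduler = create_diffusers_schedular(original_config)
#   pipe = StableDiffusionPipeline(
#       vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler,
#       safety_checker=None, feature_extractor=None, requires_safety_checker=False,
#   )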