# unet_blocks.py
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
import torch
import torch.utils.checkpoint  # explicit import; the forward passes below call torch.utils.checkpoint.checkpoint
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{down_block_type} does not exist.")
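

# Usage sketch for the factory above. All values here are illustrative
# assumptions, not taken from any shipped config of this repository:
#
#   block = get_down_block(
#       "CrossAttnDownBlock3D",
#       num_layers=2,
#       in_channels=320,
#       out_channels=640,
#       temb_channels=1280,
#       add_downsample=True,
#       resnet_eps=1e-5,
#       resnet_act_fn="silu",
#       attn_num_head_channels=8,
#       cross_attention_dim=768,
#       use_motion_module=True,
#       motion_module_type="Vanilla",  # assumed type name; see get_motion_module
#       motion_module_kwargs={},
#   )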


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{up_block_type} does not exist.")
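

# Decoder-side counterpart (again with illustrative values). Note that
# `prev_output_channel` is the channel count produced by the block directly
# below this one in the UNet, while `in_channels` sizes the last skip
# connection this block consumes (see the res_skip_channels logic in the
# up-block constructors further down):
#
#   block = get_up_block(
#       "UpBlock3D",
#       num_layers=3,
#       in_channels=320,
#       out_channels=320,
#       prev_output_channel=640,
#       temb_channels=1280,
#       add_upsample=True,
#       resnet_eps=1e-5,
#       resnet_act_fn="silu",
#       attn_num_head_channels=8,
#   )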


class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = (
                motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                if motion_module is not None
                else hidden_states
            )
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
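

# Per-layer data flow of the forward pass above, as a sketch:
#   resnets[0] -> [ attentions[i] -> motion_modules[i] (if enabled) -> resnets[i + 1] ] per layer
# Hidden states are assumed to be 5-D video latents laid out as
# (batch, channels, frames, height, width), which is the layout that
# ResnetBlock3D and Transformer3DModel in this codebase operate on.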


class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        output_states = ()

        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
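

# Note: gradient checkpointing is off by default and only takes effect in
# training mode; it is toggled by setting the attribute directly, e.g.
#
#   block.gradient_checkpointing = True  # recompute activations to save memory
#
# forward() returns (hidden_states, output_states): output_states collects one
# feature map per layer, plus the downsampled output, for the UNet skip paths.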


class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
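

# Shape sketch (assuming the (B, C, F, H, W) layout and that Downsample3D
# halves the spatial dimensions): an input of (1, 320, 16, 64, 64) with
# out_channels=320 and add_downsample=True would leave this block as
# (1, 320, 16, 32, 32), with output_states holding both the pre- and
# post-downsample features.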


class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
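

# Worked example of the skip-channel arithmetic above (illustrative numbers):
# with num_layers=3, in_channels=320, out_channels=640, prev_output_channel=1280,
# the three resnets see concatenated inputs of
#   i=0: 1280 + 640 -> 640
#   i=1:  640 + 640 -> 640
#   i=2:  640 + 320 -> 640
# i.e. res_skip_channels always matches the encoder feature popped from
# res_hidden_states_tuple at that step.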


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        hidden_states.requires_grad_(),
                        temb,
                        encoder_hidden_states,
                    )
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = (
                    motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
                    if motion_module is not None
                    else hidden_states
                )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
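

# Smoke-test sketch. This module cannot be executed directly because of its
# relative imports; the import path below is an assumption about the package
# layout, and all shapes are illustrative.
#
#   import torch
#   from animatediff.models.unet_blocks import DownBlock3D  # assumed path
#
#   block = DownBlock3D(
#       in_channels=320, out_channels=320, temb_channels=1280,
#       use_inflated_groupnorm=False,  # assumed setting; see ResnetBlock3D
#   )
#   x = torch.randn(1, 320, 16, 32, 32)  # (batch, channels, frames, height, width)
#   temb = torch.randn(1, 1280)
#   out, skips = block(x, temb)  # out: downsampled features; skips: per-layer states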