# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
import torch
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module
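
# The blocks below mirror diffusers' 2D UNet down/mid/up blocks, extended to
# video: ResnetBlock3D and Transformer3DModel operate on 5-D latents, and each
# spatial layer can be followed by an optional motion module (built by
# get_motion_module) that adds temporal mixing across frames.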


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{down_block_type} does not exist.")
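
# Illustrative call (a sketch; the argument values below are assumptions, not
# repo defaults; in practice the parent UNet's constructor supplies them):
#
#   block = get_down_block(
#       "CrossAttnDownBlock3D",
#       num_layers=2,
#       in_channels=320,
#       out_channels=640,
#       temb_channels=1280,
#       add_downsample=True,
#       resnet_eps=1e-5,
#       resnet_act_fn="silu",
#       attn_num_head_channels=8,
#       resnet_groups=32,
#       cross_attention_dim=768,
#       use_motion_module=True,
#       motion_module_type="Vanilla",  # assumed type name
#       motion_module_kwargs={},
#   )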


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{up_block_type} does not exist.")
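
# get_up_block mirrors get_down_block; the extra prev_output_channel argument
# exists because each up-block resnet consumes the running feature map
# concatenated with an encoder skip tensor (see res_skip_channels /
# resnet_in_channels in the up-block classes below).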


class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
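
# Schematic data flow of the mid block above, assuming the 5-D
# (batch, channels, frames, height, width) latent layout used by this repo:
#
#   h = resnets[0](h, temb)
#   repeat num_layers times:
#       h = attn(h, encoder_hidden_states).sample    # spatial / cross attention
#       h = motion_module(h, ...)                    # optional temporal mixing
#       h = resnet(h, temb)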


class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        output_states = ()

        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
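
# output_states accumulates one tensor per (resnet, attn, motion) layer plus
# the downsampled map, so the parent UNet can hand them to the matching
# up block as skip connections.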


class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
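
# Note on the checkpointed branches above: the hidden_states.requires_grad_()
# call before checkpointing the motion module looks like a workaround for
# torch.utils.checkpoint, which produces no gradients (and warns) when none of
# its inputs require grad, e.g. when all trainable weights sit inside the
# checkpointed module itself.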


class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
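
# Skip tensors are consumed in LIFO order: res_hidden_states_tuple[-1], the
# most recently produced encoder feature, is concatenated first, pairing each
# up-block resnet with its mirror-image down-block layer.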


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
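

if __name__ == "__main__":
    # Minimal smoke-test sketch. The shapes and argument values below are
    # assumptions; because of the relative imports above, this only runs when
    # invoked as a module from the package root
    # (e.g. `python -m <package>.unet_blocks`).
    block = DownBlock3D(
        in_channels=64,
        out_channels=64,
        temb_channels=128,
        num_layers=1,
        add_downsample=True,
        use_motion_module=False,
    )
    x = torch.randn(1, 64, 8, 32, 32)  # (batch, channels, frames, height, width)
    temb = torch.randn(1, 128)
    out, skips = block(x, temb)
    print(out.shape, [tuple(s.shape) for s in skips])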