# nn.py — auto-generated FATE-Torch wrapper layers for torch.nn modules.
from torch import nn
from federatedml.nn.backend.torch.base import FateTorchLayer, FateTorchLoss
from federatedml.nn.backend.torch.base import Sequential
  4. class Bilinear(nn.modules.linear.Bilinear, FateTorchLayer):
  5. def __init__(
  6. self,
  7. in1_features,
  8. in2_features,
  9. out_features,
  10. bias=True,
  11. device=None,
  12. dtype=None,
  13. **kwargs):
  14. FateTorchLayer.__init__(self)
  15. self.param_dict['bias'] = bias
  16. self.param_dict['device'] = device
  17. self.param_dict['dtype'] = dtype
  18. self.param_dict['in1_features'] = in1_features
  19. self.param_dict['in2_features'] = in2_features
  20. self.param_dict['out_features'] = out_features
  21. self.param_dict.update(kwargs)
  22. nn.modules.linear.Bilinear.__init__(self, **self.param_dict)
  23. class Identity(nn.modules.linear.Identity, FateTorchLayer):
  24. def __init__(self, **kwargs):
  25. FateTorchLayer.__init__(self)
  26. self.param_dict.update(kwargs)
  27. nn.modules.linear.Identity.__init__(self, **self.param_dict)
  28. class LazyLinear(nn.modules.linear.LazyLinear, FateTorchLayer):
  29. def __init__(
  30. self,
  31. out_features,
  32. bias=True,
  33. device=None,
  34. dtype=None,
  35. **kwargs):
  36. FateTorchLayer.__init__(self)
  37. self.param_dict['bias'] = bias
  38. self.param_dict['device'] = device
  39. self.param_dict['dtype'] = dtype
  40. self.param_dict['out_features'] = out_features
  41. self.param_dict.update(kwargs)
  42. nn.modules.linear.LazyLinear.__init__(self, **self.param_dict)
  43. class Linear(nn.modules.linear.Linear, FateTorchLayer):
  44. def __init__(
  45. self,
  46. in_features,
  47. out_features,
  48. bias=True,
  49. device=None,
  50. dtype=None,
  51. **kwargs):
  52. FateTorchLayer.__init__(self)
  53. self.param_dict['bias'] = bias
  54. self.param_dict['device'] = device
  55. self.param_dict['dtype'] = dtype
  56. self.param_dict['in_features'] = in_features
  57. self.param_dict['out_features'] = out_features
  58. self.param_dict.update(kwargs)
  59. nn.modules.linear.Linear.__init__(self, **self.param_dict)
  60. class NonDynamicallyQuantizableLinear(
  61. nn.modules.linear.NonDynamicallyQuantizableLinear,
  62. FateTorchLayer):
  63. def __init__(
  64. self,
  65. in_features,
  66. out_features,
  67. bias=True,
  68. device=None,
  69. dtype=None,
  70. **kwargs):
  71. FateTorchLayer.__init__(self)
  72. self.param_dict['bias'] = bias
  73. self.param_dict['device'] = device
  74. self.param_dict['dtype'] = dtype
  75. self.param_dict['in_features'] = in_features
  76. self.param_dict['out_features'] = out_features
  77. self.param_dict.update(kwargs)
  78. nn.modules.linear.NonDynamicallyQuantizableLinear.__init__(
  79. self, **self.param_dict)
  80. class GRU(nn.modules.rnn.GRU, FateTorchLayer):
  81. def __init__(self, **kwargs):
  82. FateTorchLayer.__init__(self)
  83. self.param_dict.update(kwargs)
  84. nn.modules.rnn.GRU.__init__(self, **self.param_dict)
  85. class GRUCell(nn.modules.rnn.GRUCell, FateTorchLayer):
  86. def __init__(
  87. self,
  88. input_size,
  89. hidden_size,
  90. bias=True,
  91. device=None,
  92. dtype=None,
  93. **kwargs):
  94. FateTorchLayer.__init__(self)
  95. self.param_dict['bias'] = bias
  96. self.param_dict['device'] = device
  97. self.param_dict['dtype'] = dtype
  98. self.param_dict['input_size'] = input_size
  99. self.param_dict['hidden_size'] = hidden_size
  100. self.param_dict.update(kwargs)
  101. nn.modules.rnn.GRUCell.__init__(self, **self.param_dict)
  102. class LSTM(nn.modules.rnn.LSTM, FateTorchLayer):
  103. def __init__(self, **kwargs):
  104. FateTorchLayer.__init__(self)
  105. self.param_dict.update(kwargs)
  106. nn.modules.rnn.LSTM.__init__(self, **self.param_dict)
  107. class LSTMCell(nn.modules.rnn.LSTMCell, FateTorchLayer):
  108. def __init__(
  109. self,
  110. input_size,
  111. hidden_size,
  112. bias=True,
  113. device=None,
  114. dtype=None,
  115. **kwargs):
  116. FateTorchLayer.__init__(self)
  117. self.param_dict['bias'] = bias
  118. self.param_dict['device'] = device
  119. self.param_dict['dtype'] = dtype
  120. self.param_dict['input_size'] = input_size
  121. self.param_dict['hidden_size'] = hidden_size
  122. self.param_dict.update(kwargs)
  123. nn.modules.rnn.LSTMCell.__init__(self, **self.param_dict)
  124. class RNN(nn.modules.rnn.RNN, FateTorchLayer):
  125. def __init__(self, **kwargs):
  126. FateTorchLayer.__init__(self)
  127. self.param_dict.update(kwargs)
  128. nn.modules.rnn.RNN.__init__(self, **self.param_dict)
  129. class RNNBase(nn.modules.rnn.RNNBase, FateTorchLayer):
  130. def __init__(
  131. self,
  132. mode,
  133. input_size,
  134. hidden_size,
  135. num_layers=1,
  136. bias=True,
  137. batch_first=False,
  138. dropout=0.0,
  139. bidirectional=False,
  140. proj_size=0,
  141. device=None,
  142. dtype=None,
  143. **kwargs):
  144. FateTorchLayer.__init__(self)
  145. self.param_dict['num_layers'] = num_layers
  146. self.param_dict['bias'] = bias
  147. self.param_dict['batch_first'] = batch_first
  148. self.param_dict['dropout'] = dropout
  149. self.param_dict['bidirectional'] = bidirectional
  150. self.param_dict['proj_size'] = proj_size
  151. self.param_dict['device'] = device
  152. self.param_dict['dtype'] = dtype
  153. self.param_dict['mode'] = mode
  154. self.param_dict['input_size'] = input_size
  155. self.param_dict['hidden_size'] = hidden_size
  156. self.param_dict.update(kwargs)
  157. nn.modules.rnn.RNNBase.__init__(self, **self.param_dict)
  158. class RNNCell(nn.modules.rnn.RNNCell, FateTorchLayer):
  159. def __init__(
  160. self,
  161. input_size,
  162. hidden_size,
  163. bias=True,
  164. nonlinearity='tanh',
  165. device=None,
  166. dtype=None,
  167. **kwargs):
  168. FateTorchLayer.__init__(self)
  169. self.param_dict['bias'] = bias
  170. self.param_dict['nonlinearity'] = nonlinearity
  171. self.param_dict['device'] = device
  172. self.param_dict['dtype'] = dtype
  173. self.param_dict['input_size'] = input_size
  174. self.param_dict['hidden_size'] = hidden_size
  175. self.param_dict.update(kwargs)
  176. nn.modules.rnn.RNNCell.__init__(self, **self.param_dict)
  177. class RNNCellBase(nn.modules.rnn.RNNCellBase, FateTorchLayer):
  178. def __init__(
  179. self,
  180. input_size,
  181. hidden_size,
  182. bias,
  183. num_chunks,
  184. device=None,
  185. dtype=None,
  186. **kwargs):
  187. FateTorchLayer.__init__(self)
  188. self.param_dict['device'] = device
  189. self.param_dict['dtype'] = dtype
  190. self.param_dict['input_size'] = input_size
  191. self.param_dict['hidden_size'] = hidden_size
  192. self.param_dict['bias'] = bias
  193. self.param_dict['num_chunks'] = num_chunks
  194. self.param_dict.update(kwargs)
  195. nn.modules.rnn.RNNCellBase.__init__(self, **self.param_dict)
  196. class Embedding(nn.modules.sparse.Embedding, FateTorchLayer):
  197. def __init__(
  198. self,
  199. num_embeddings,
  200. embedding_dim,
  201. padding_idx=None,
  202. max_norm=None,
  203. norm_type=2.0,
  204. scale_grad_by_freq=False,
  205. sparse=False,
  206. _weight=None,
  207. device=None,
  208. dtype=None,
  209. **kwargs):
  210. FateTorchLayer.__init__(self)
  211. self.param_dict['padding_idx'] = padding_idx
  212. self.param_dict['max_norm'] = max_norm
  213. self.param_dict['norm_type'] = norm_type
  214. self.param_dict['scale_grad_by_freq'] = scale_grad_by_freq
  215. self.param_dict['sparse'] = sparse
  216. self.param_dict['_weight'] = _weight
  217. self.param_dict['device'] = device
  218. self.param_dict['dtype'] = dtype
  219. self.param_dict['num_embeddings'] = num_embeddings
  220. self.param_dict['embedding_dim'] = embedding_dim
  221. self.param_dict.update(kwargs)
  222. nn.modules.sparse.Embedding.__init__(self, **self.param_dict)
  223. class EmbeddingBag(nn.modules.sparse.EmbeddingBag, FateTorchLayer):
  224. def __init__(
  225. self,
  226. num_embeddings,
  227. embedding_dim,
  228. max_norm=None,
  229. norm_type=2.0,
  230. scale_grad_by_freq=False,
  231. mode='mean',
  232. sparse=False,
  233. _weight=None,
  234. include_last_offset=False,
  235. padding_idx=None,
  236. device=None,
  237. dtype=None,
  238. **kwargs):
  239. FateTorchLayer.__init__(self)
  240. self.param_dict['max_norm'] = max_norm
  241. self.param_dict['norm_type'] = norm_type
  242. self.param_dict['scale_grad_by_freq'] = scale_grad_by_freq
  243. self.param_dict['mode'] = mode
  244. self.param_dict['sparse'] = sparse
  245. self.param_dict['_weight'] = _weight
  246. self.param_dict['include_last_offset'] = include_last_offset
  247. self.param_dict['padding_idx'] = padding_idx
  248. self.param_dict['device'] = device
  249. self.param_dict['dtype'] = dtype
  250. self.param_dict['num_embeddings'] = num_embeddings
  251. self.param_dict['embedding_dim'] = embedding_dim
  252. self.param_dict.update(kwargs)
  253. nn.modules.sparse.EmbeddingBag.__init__(self, **self.param_dict)
  254. class AlphaDropout(nn.modules.dropout.AlphaDropout, FateTorchLayer):
  255. def __init__(self, p=0.5, inplace=False, **kwargs):
  256. FateTorchLayer.__init__(self)
  257. self.param_dict['p'] = p
  258. self.param_dict['inplace'] = inplace
  259. self.param_dict.update(kwargs)
  260. nn.modules.dropout.AlphaDropout.__init__(self, **self.param_dict)
  261. class Dropout(nn.modules.dropout.Dropout, FateTorchLayer):
  262. def __init__(self, p=0.5, inplace=False, **kwargs):
  263. FateTorchLayer.__init__(self)
  264. self.param_dict['p'] = p
  265. self.param_dict['inplace'] = inplace
  266. self.param_dict.update(kwargs)
  267. nn.modules.dropout.Dropout.__init__(self, **self.param_dict)
  268. class Dropout1d(nn.modules.dropout.Dropout1d, FateTorchLayer):
  269. def __init__(self, p=0.5, inplace=False, **kwargs):
  270. FateTorchLayer.__init__(self)
  271. self.param_dict['p'] = p
  272. self.param_dict['inplace'] = inplace
  273. self.param_dict.update(kwargs)
  274. nn.modules.dropout.Dropout1d.__init__(self, **self.param_dict)
  275. class Dropout2d(nn.modules.dropout.Dropout2d, FateTorchLayer):
  276. def __init__(self, p=0.5, inplace=False, **kwargs):
  277. FateTorchLayer.__init__(self)
  278. self.param_dict['p'] = p
  279. self.param_dict['inplace'] = inplace
  280. self.param_dict.update(kwargs)
  281. nn.modules.dropout.Dropout2d.__init__(self, **self.param_dict)
  282. class Dropout3d(nn.modules.dropout.Dropout3d, FateTorchLayer):
  283. def __init__(self, p=0.5, inplace=False, **kwargs):
  284. FateTorchLayer.__init__(self)
  285. self.param_dict['p'] = p
  286. self.param_dict['inplace'] = inplace
  287. self.param_dict.update(kwargs)
  288. nn.modules.dropout.Dropout3d.__init__(self, **self.param_dict)
  289. class FeatureAlphaDropout(
  290. nn.modules.dropout.FeatureAlphaDropout,
  291. FateTorchLayer):
  292. def __init__(self, p=0.5, inplace=False, **kwargs):
  293. FateTorchLayer.__init__(self)
  294. self.param_dict['p'] = p
  295. self.param_dict['inplace'] = inplace
  296. self.param_dict.update(kwargs)
  297. nn.modules.dropout.FeatureAlphaDropout.__init__(
  298. self, **self.param_dict)
  299. class _DropoutNd(nn.modules.dropout._DropoutNd, FateTorchLayer):
  300. def __init__(self, p=0.5, inplace=False, **kwargs):
  301. FateTorchLayer.__init__(self)
  302. self.param_dict['p'] = p
  303. self.param_dict['inplace'] = inplace
  304. self.param_dict.update(kwargs)
  305. nn.modules.dropout._DropoutNd.__init__(self, **self.param_dict)
  306. class CELU(nn.modules.activation.CELU, FateTorchLayer):
  307. def __init__(self, alpha=1.0, inplace=False, **kwargs):
  308. FateTorchLayer.__init__(self)
  309. self.param_dict['alpha'] = alpha
  310. self.param_dict['inplace'] = inplace
  311. self.param_dict.update(kwargs)
  312. nn.modules.activation.CELU.__init__(self, **self.param_dict)
  313. class ELU(nn.modules.activation.ELU, FateTorchLayer):
  314. def __init__(self, alpha=1.0, inplace=False, **kwargs):
  315. FateTorchLayer.__init__(self)
  316. self.param_dict['alpha'] = alpha
  317. self.param_dict['inplace'] = inplace
  318. self.param_dict.update(kwargs)
  319. nn.modules.activation.ELU.__init__(self, **self.param_dict)
  320. class GELU(nn.modules.activation.GELU, FateTorchLayer):
  321. def __init__(self, approximate='none', **kwargs):
  322. FateTorchLayer.__init__(self)
  323. self.param_dict['approximate'] = approximate
  324. self.param_dict.update(kwargs)
  325. nn.modules.activation.GELU.__init__(self, **self.param_dict)
  326. class GLU(nn.modules.activation.GLU, FateTorchLayer):
  327. def __init__(self, dim=-1, **kwargs):
  328. FateTorchLayer.__init__(self)
  329. self.param_dict['dim'] = dim
  330. self.param_dict.update(kwargs)
  331. nn.modules.activation.GLU.__init__(self, **self.param_dict)
  332. class Hardshrink(nn.modules.activation.Hardshrink, FateTorchLayer):
  333. def __init__(self, lambd=0.5, **kwargs):
  334. FateTorchLayer.__init__(self)
  335. self.param_dict['lambd'] = lambd
  336. self.param_dict.update(kwargs)
  337. nn.modules.activation.Hardshrink.__init__(self, **self.param_dict)
  338. class Hardsigmoid(nn.modules.activation.Hardsigmoid, FateTorchLayer):
  339. def __init__(self, inplace=False, **kwargs):
  340. FateTorchLayer.__init__(self)
  341. self.param_dict['inplace'] = inplace
  342. self.param_dict.update(kwargs)
  343. nn.modules.activation.Hardsigmoid.__init__(self, **self.param_dict)
  344. class Hardswish(nn.modules.activation.Hardswish, FateTorchLayer):
  345. def __init__(self, inplace=False, **kwargs):
  346. FateTorchLayer.__init__(self)
  347. self.param_dict['inplace'] = inplace
  348. self.param_dict.update(kwargs)
  349. nn.modules.activation.Hardswish.__init__(self, **self.param_dict)
  350. class Hardtanh(nn.modules.activation.Hardtanh, FateTorchLayer):
  351. def __init__(
  352. self,
  353. min_val=-1.0,
  354. max_val=1.0,
  355. inplace=False,
  356. min_value=None,
  357. max_value=None,
  358. **kwargs):
  359. FateTorchLayer.__init__(self)
  360. self.param_dict['min_val'] = min_val
  361. self.param_dict['max_val'] = max_val
  362. self.param_dict['inplace'] = inplace
  363. self.param_dict['min_value'] = min_value
  364. self.param_dict['max_value'] = max_value
  365. self.param_dict.update(kwargs)
  366. nn.modules.activation.Hardtanh.__init__(self, **self.param_dict)
  367. class LeakyReLU(nn.modules.activation.LeakyReLU, FateTorchLayer):
  368. def __init__(self, negative_slope=0.01, inplace=False, **kwargs):
  369. FateTorchLayer.__init__(self)
  370. self.param_dict['negative_slope'] = negative_slope
  371. self.param_dict['inplace'] = inplace
  372. self.param_dict.update(kwargs)
  373. nn.modules.activation.LeakyReLU.__init__(self, **self.param_dict)
  374. class LogSigmoid(nn.modules.activation.LogSigmoid, FateTorchLayer):
  375. def __init__(self, **kwargs):
  376. FateTorchLayer.__init__(self)
  377. self.param_dict.update(kwargs)
  378. nn.modules.activation.LogSigmoid.__init__(self, **self.param_dict)
  379. class LogSoftmax(nn.modules.activation.LogSoftmax, FateTorchLayer):
  380. def __init__(self, dim=None, **kwargs):
  381. FateTorchLayer.__init__(self)
  382. self.param_dict['dim'] = dim
  383. self.param_dict.update(kwargs)
  384. nn.modules.activation.LogSoftmax.__init__(self, **self.param_dict)
  385. class Mish(nn.modules.activation.Mish, FateTorchLayer):
  386. def __init__(self, inplace=False, **kwargs):
  387. FateTorchLayer.__init__(self)
  388. self.param_dict['inplace'] = inplace
  389. self.param_dict.update(kwargs)
  390. nn.modules.activation.Mish.__init__(self, **self.param_dict)
  391. class MultiheadAttention(
  392. nn.modules.activation.MultiheadAttention,
  393. FateTorchLayer):
  394. def __init__(
  395. self,
  396. embed_dim,
  397. num_heads,
  398. dropout=0.0,
  399. bias=True,
  400. add_bias_kv=False,
  401. add_zero_attn=False,
  402. kdim=None,
  403. vdim=None,
  404. batch_first=False,
  405. device=None,
  406. dtype=None,
  407. **kwargs):
  408. FateTorchLayer.__init__(self)
  409. self.param_dict['dropout'] = dropout
  410. self.param_dict['bias'] = bias
  411. self.param_dict['add_bias_kv'] = add_bias_kv
  412. self.param_dict['add_zero_attn'] = add_zero_attn
  413. self.param_dict['kdim'] = kdim
  414. self.param_dict['vdim'] = vdim
  415. self.param_dict['batch_first'] = batch_first
  416. self.param_dict['device'] = device
  417. self.param_dict['dtype'] = dtype
  418. self.param_dict['embed_dim'] = embed_dim
  419. self.param_dict['num_heads'] = num_heads
  420. self.param_dict.update(kwargs)
  421. nn.modules.activation.MultiheadAttention.__init__(
  422. self, **self.param_dict)
  423. class PReLU(nn.modules.activation.PReLU, FateTorchLayer):
  424. def __init__(
  425. self,
  426. num_parameters=1,
  427. init=0.25,
  428. device=None,
  429. dtype=None,
  430. **kwargs):
  431. FateTorchLayer.__init__(self)
  432. self.param_dict['num_parameters'] = num_parameters
  433. self.param_dict['init'] = init
  434. self.param_dict['device'] = device
  435. self.param_dict['dtype'] = dtype
  436. self.param_dict.update(kwargs)
  437. nn.modules.activation.PReLU.__init__(self, **self.param_dict)
  438. class RReLU(nn.modules.activation.RReLU, FateTorchLayer):
  439. def __init__(
  440. self,
  441. lower=0.125,
  442. upper=0.3333333333333333,
  443. inplace=False,
  444. **kwargs):
  445. FateTorchLayer.__init__(self)
  446. self.param_dict['lower'] = lower
  447. self.param_dict['upper'] = upper
  448. self.param_dict['inplace'] = inplace
  449. self.param_dict.update(kwargs)
  450. nn.modules.activation.RReLU.__init__(self, **self.param_dict)
  451. class ReLU(nn.modules.activation.ReLU, FateTorchLayer):
  452. def __init__(self, inplace=False, **kwargs):
  453. FateTorchLayer.__init__(self)
  454. self.param_dict['inplace'] = inplace
  455. self.param_dict.update(kwargs)
  456. nn.modules.activation.ReLU.__init__(self, **self.param_dict)
  457. class ReLU6(nn.modules.activation.ReLU6, FateTorchLayer):
  458. def __init__(self, inplace=False, **kwargs):
  459. FateTorchLayer.__init__(self)
  460. self.param_dict['inplace'] = inplace
  461. self.param_dict.update(kwargs)
  462. nn.modules.activation.ReLU6.__init__(self, **self.param_dict)
  463. class SELU(nn.modules.activation.SELU, FateTorchLayer):
  464. def __init__(self, inplace=False, **kwargs):
  465. FateTorchLayer.__init__(self)
  466. self.param_dict['inplace'] = inplace
  467. self.param_dict.update(kwargs)
  468. nn.modules.activation.SELU.__init__(self, **self.param_dict)
  469. class SiLU(nn.modules.activation.SiLU, FateTorchLayer):
  470. def __init__(self, inplace=False, **kwargs):
  471. FateTorchLayer.__init__(self)
  472. self.param_dict['inplace'] = inplace
  473. self.param_dict.update(kwargs)
  474. nn.modules.activation.SiLU.__init__(self, **self.param_dict)
  475. class Sigmoid(nn.modules.activation.Sigmoid, FateTorchLayer):
  476. def __init__(self, **kwargs):
  477. FateTorchLayer.__init__(self)
  478. self.param_dict.update(kwargs)
  479. nn.modules.activation.Sigmoid.__init__(self, **self.param_dict)
  480. class Softmax(nn.modules.activation.Softmax, FateTorchLayer):
  481. def __init__(self, dim=None, **kwargs):
  482. FateTorchLayer.__init__(self)
  483. self.param_dict['dim'] = dim
  484. self.param_dict.update(kwargs)
  485. nn.modules.activation.Softmax.__init__(self, **self.param_dict)
  486. class Softmax2d(nn.modules.activation.Softmax2d, FateTorchLayer):
  487. def __init__(self, **kwargs):
  488. FateTorchLayer.__init__(self)
  489. self.param_dict.update(kwargs)
  490. nn.modules.activation.Softmax2d.__init__(self, **self.param_dict)
  491. class Softmin(nn.modules.activation.Softmin, FateTorchLayer):
  492. def __init__(self, dim=None, **kwargs):
  493. FateTorchLayer.__init__(self)
  494. self.param_dict['dim'] = dim
  495. self.param_dict.update(kwargs)
  496. nn.modules.activation.Softmin.__init__(self, **self.param_dict)
  497. class Softplus(nn.modules.activation.Softplus, FateTorchLayer):
  498. def __init__(self, beta=1, threshold=20, **kwargs):
  499. FateTorchLayer.__init__(self)
  500. self.param_dict['beta'] = beta
  501. self.param_dict['threshold'] = threshold
  502. self.param_dict.update(kwargs)
  503. nn.modules.activation.Softplus.__init__(self, **self.param_dict)
  504. class Softshrink(nn.modules.activation.Softshrink, FateTorchLayer):
  505. def __init__(self, lambd=0.5, **kwargs):
  506. FateTorchLayer.__init__(self)
  507. self.param_dict['lambd'] = lambd
  508. self.param_dict.update(kwargs)
  509. nn.modules.activation.Softshrink.__init__(self, **self.param_dict)
  510. class Softsign(nn.modules.activation.Softsign, FateTorchLayer):
  511. def __init__(self, **kwargs):
  512. FateTorchLayer.__init__(self)
  513. self.param_dict.update(kwargs)
  514. nn.modules.activation.Softsign.__init__(self, **self.param_dict)
  515. class Tanh(nn.modules.activation.Tanh, FateTorchLayer):
  516. def __init__(self, **kwargs):
  517. FateTorchLayer.__init__(self)
  518. self.param_dict.update(kwargs)
  519. nn.modules.activation.Tanh.__init__(self, **self.param_dict)
  520. class Tanhshrink(nn.modules.activation.Tanhshrink, FateTorchLayer):
  521. def __init__(self, **kwargs):
  522. FateTorchLayer.__init__(self)
  523. self.param_dict.update(kwargs)
  524. nn.modules.activation.Tanhshrink.__init__(self, **self.param_dict)
  525. class Threshold(nn.modules.activation.Threshold, FateTorchLayer):
  526. def __init__(self, threshold, value, inplace=False, **kwargs):
  527. FateTorchLayer.__init__(self)
  528. self.param_dict['inplace'] = inplace
  529. self.param_dict['threshold'] = threshold
  530. self.param_dict['value'] = value
  531. self.param_dict.update(kwargs)
  532. nn.modules.activation.Threshold.__init__(self, **self.param_dict)
  533. class Conv1d(nn.modules.conv.Conv1d, FateTorchLayer):
  534. def __init__(
  535. self,
  536. in_channels,
  537. out_channels,
  538. kernel_size,
  539. stride=1,
  540. padding=0,
  541. dilation=1,
  542. groups=1,
  543. bias=True,
  544. padding_mode='zeros',
  545. device=None,
  546. dtype=None,
  547. **kwargs):
  548. FateTorchLayer.__init__(self)
  549. self.param_dict['stride'] = stride
  550. self.param_dict['padding'] = padding
  551. self.param_dict['dilation'] = dilation
  552. self.param_dict['groups'] = groups
  553. self.param_dict['bias'] = bias
  554. self.param_dict['padding_mode'] = padding_mode
  555. self.param_dict['device'] = device
  556. self.param_dict['dtype'] = dtype
  557. self.param_dict['in_channels'] = in_channels
  558. self.param_dict['out_channels'] = out_channels
  559. self.param_dict['kernel_size'] = kernel_size
  560. self.param_dict.update(kwargs)
  561. nn.modules.conv.Conv1d.__init__(self, **self.param_dict)
  562. class Conv2d(nn.modules.conv.Conv2d, FateTorchLayer):
  563. def __init__(
  564. self,
  565. in_channels,
  566. out_channels,
  567. kernel_size,
  568. stride=1,
  569. padding=0,
  570. dilation=1,
  571. groups=1,
  572. bias=True,
  573. padding_mode='zeros',
  574. device=None,
  575. dtype=None,
  576. **kwargs):
  577. FateTorchLayer.__init__(self)
  578. self.param_dict['stride'] = stride
  579. self.param_dict['padding'] = padding
  580. self.param_dict['dilation'] = dilation
  581. self.param_dict['groups'] = groups
  582. self.param_dict['bias'] = bias
  583. self.param_dict['padding_mode'] = padding_mode
  584. self.param_dict['device'] = device
  585. self.param_dict['dtype'] = dtype
  586. self.param_dict['in_channels'] = in_channels
  587. self.param_dict['out_channels'] = out_channels
  588. self.param_dict['kernel_size'] = kernel_size
  589. self.param_dict.update(kwargs)
  590. nn.modules.conv.Conv2d.__init__(self, **self.param_dict)
  591. class Conv3d(nn.modules.conv.Conv3d, FateTorchLayer):
  592. def __init__(
  593. self,
  594. in_channels,
  595. out_channels,
  596. kernel_size,
  597. stride=1,
  598. padding=0,
  599. dilation=1,
  600. groups=1,
  601. bias=True,
  602. padding_mode='zeros',
  603. device=None,
  604. dtype=None,
  605. **kwargs):
  606. FateTorchLayer.__init__(self)
  607. self.param_dict['stride'] = stride
  608. self.param_dict['padding'] = padding
  609. self.param_dict['dilation'] = dilation
  610. self.param_dict['groups'] = groups
  611. self.param_dict['bias'] = bias
  612. self.param_dict['padding_mode'] = padding_mode
  613. self.param_dict['device'] = device
  614. self.param_dict['dtype'] = dtype
  615. self.param_dict['in_channels'] = in_channels
  616. self.param_dict['out_channels'] = out_channels
  617. self.param_dict['kernel_size'] = kernel_size
  618. self.param_dict.update(kwargs)
  619. nn.modules.conv.Conv3d.__init__(self, **self.param_dict)
  620. class ConvTranspose1d(nn.modules.conv.ConvTranspose1d, FateTorchLayer):
  621. def __init__(
  622. self,
  623. in_channels,
  624. out_channels,
  625. kernel_size,
  626. stride=1,
  627. padding=0,
  628. output_padding=0,
  629. groups=1,
  630. bias=True,
  631. dilation=1,
  632. padding_mode='zeros',
  633. device=None,
  634. dtype=None,
  635. **kwargs):
  636. FateTorchLayer.__init__(self)
  637. self.param_dict['stride'] = stride
  638. self.param_dict['padding'] = padding
  639. self.param_dict['output_padding'] = output_padding
  640. self.param_dict['groups'] = groups
  641. self.param_dict['bias'] = bias
  642. self.param_dict['dilation'] = dilation
  643. self.param_dict['padding_mode'] = padding_mode
  644. self.param_dict['device'] = device
  645. self.param_dict['dtype'] = dtype
  646. self.param_dict['in_channels'] = in_channels
  647. self.param_dict['out_channels'] = out_channels
  648. self.param_dict['kernel_size'] = kernel_size
  649. self.param_dict.update(kwargs)
  650. nn.modules.conv.ConvTranspose1d.__init__(self, **self.param_dict)
  651. class ConvTranspose2d(nn.modules.conv.ConvTranspose2d, FateTorchLayer):
  652. def __init__(
  653. self,
  654. in_channels,
  655. out_channels,
  656. kernel_size,
  657. stride=1,
  658. padding=0,
  659. output_padding=0,
  660. groups=1,
  661. bias=True,
  662. dilation=1,
  663. padding_mode='zeros',
  664. device=None,
  665. dtype=None,
  666. **kwargs):
  667. FateTorchLayer.__init__(self)
  668. self.param_dict['stride'] = stride
  669. self.param_dict['padding'] = padding
  670. self.param_dict['output_padding'] = output_padding
  671. self.param_dict['groups'] = groups
  672. self.param_dict['bias'] = bias
  673. self.param_dict['dilation'] = dilation
  674. self.param_dict['padding_mode'] = padding_mode
  675. self.param_dict['device'] = device
  676. self.param_dict['dtype'] = dtype
  677. self.param_dict['in_channels'] = in_channels
  678. self.param_dict['out_channels'] = out_channels
  679. self.param_dict['kernel_size'] = kernel_size
  680. self.param_dict.update(kwargs)
  681. nn.modules.conv.ConvTranspose2d.__init__(self, **self.param_dict)
  682. class ConvTranspose3d(nn.modules.conv.ConvTranspose3d, FateTorchLayer):
  683. def __init__(
  684. self,
  685. in_channels,
  686. out_channels,
  687. kernel_size,
  688. stride=1,
  689. padding=0,
  690. output_padding=0,
  691. groups=1,
  692. bias=True,
  693. dilation=1,
  694. padding_mode='zeros',
  695. device=None,
  696. dtype=None,
  697. **kwargs):
  698. FateTorchLayer.__init__(self)
  699. self.param_dict['stride'] = stride
  700. self.param_dict['padding'] = padding
  701. self.param_dict['output_padding'] = output_padding
  702. self.param_dict['groups'] = groups
  703. self.param_dict['bias'] = bias
  704. self.param_dict['dilation'] = dilation
  705. self.param_dict['padding_mode'] = padding_mode
  706. self.param_dict['device'] = device
  707. self.param_dict['dtype'] = dtype
  708. self.param_dict['in_channels'] = in_channels
  709. self.param_dict['out_channels'] = out_channels
  710. self.param_dict['kernel_size'] = kernel_size
  711. self.param_dict.update(kwargs)
  712. nn.modules.conv.ConvTranspose3d.__init__(self, **self.param_dict)
  713. class LazyConv1d(nn.modules.conv.LazyConv1d, FateTorchLayer):
  714. def __init__(
  715. self,
  716. out_channels,
  717. kernel_size,
  718. stride=1,
  719. padding=0,
  720. dilation=1,
  721. groups=1,
  722. bias=True,
  723. padding_mode='zeros',
  724. device=None,
  725. dtype=None,
  726. **kwargs):
  727. FateTorchLayer.__init__(self)
  728. self.param_dict['stride'] = stride
  729. self.param_dict['padding'] = padding
  730. self.param_dict['dilation'] = dilation
  731. self.param_dict['groups'] = groups
  732. self.param_dict['bias'] = bias
  733. self.param_dict['padding_mode'] = padding_mode
  734. self.param_dict['device'] = device
  735. self.param_dict['dtype'] = dtype
  736. self.param_dict['out_channels'] = out_channels
  737. self.param_dict['kernel_size'] = kernel_size
  738. self.param_dict.update(kwargs)
  739. nn.modules.conv.LazyConv1d.__init__(self, **self.param_dict)
  740. class LazyConv2d(nn.modules.conv.LazyConv2d, FateTorchLayer):
  741. def __init__(
  742. self,
  743. out_channels,
  744. kernel_size,
  745. stride=1,
  746. padding=0,
  747. dilation=1,
  748. groups=1,
  749. bias=True,
  750. padding_mode='zeros',
  751. device=None,
  752. dtype=None,
  753. **kwargs):
  754. FateTorchLayer.__init__(self)
  755. self.param_dict['stride'] = stride
  756. self.param_dict['padding'] = padding
  757. self.param_dict['dilation'] = dilation
  758. self.param_dict['groups'] = groups
  759. self.param_dict['bias'] = bias
  760. self.param_dict['padding_mode'] = padding_mode
  761. self.param_dict['device'] = device
  762. self.param_dict['dtype'] = dtype
  763. self.param_dict['out_channels'] = out_channels
  764. self.param_dict['kernel_size'] = kernel_size
  765. self.param_dict.update(kwargs)
  766. nn.modules.conv.LazyConv2d.__init__(self, **self.param_dict)
  767. class LazyConv3d(nn.modules.conv.LazyConv3d, FateTorchLayer):
  768. def __init__(
  769. self,
  770. out_channels,
  771. kernel_size,
  772. stride=1,
  773. padding=0,
  774. dilation=1,
  775. groups=1,
  776. bias=True,
  777. padding_mode='zeros',
  778. device=None,
  779. dtype=None,
  780. **kwargs):
  781. FateTorchLayer.__init__(self)
  782. self.param_dict['stride'] = stride
  783. self.param_dict['padding'] = padding
  784. self.param_dict['dilation'] = dilation
  785. self.param_dict['groups'] = groups
  786. self.param_dict['bias'] = bias
  787. self.param_dict['padding_mode'] = padding_mode
  788. self.param_dict['device'] = device
  789. self.param_dict['dtype'] = dtype
  790. self.param_dict['out_channels'] = out_channels
  791. self.param_dict['kernel_size'] = kernel_size
  792. self.param_dict.update(kwargs)
  793. nn.modules.conv.LazyConv3d.__init__(self, **self.param_dict)
  794. class LazyConvTranspose1d(nn.modules.conv.LazyConvTranspose1d, FateTorchLayer):
  795. def __init__(
  796. self,
  797. out_channels,
  798. kernel_size,
  799. stride=1,
  800. padding=0,
  801. output_padding=0,
  802. groups=1,
  803. bias=True,
  804. dilation=1,
  805. padding_mode='zeros',
  806. device=None,
  807. dtype=None,
  808. **kwargs):
  809. FateTorchLayer.__init__(self)
  810. self.param_dict['stride'] = stride
  811. self.param_dict['padding'] = padding
  812. self.param_dict['output_padding'] = output_padding
  813. self.param_dict['groups'] = groups
  814. self.param_dict['bias'] = bias
  815. self.param_dict['dilation'] = dilation
  816. self.param_dict['padding_mode'] = padding_mode
  817. self.param_dict['device'] = device
  818. self.param_dict['dtype'] = dtype
  819. self.param_dict['out_channels'] = out_channels
  820. self.param_dict['kernel_size'] = kernel_size
  821. self.param_dict.update(kwargs)
  822. nn.modules.conv.LazyConvTranspose1d.__init__(self, **self.param_dict)
  823. class LazyConvTranspose2d(nn.modules.conv.LazyConvTranspose2d, FateTorchLayer):
  824. def __init__(
  825. self,
  826. out_channels,
  827. kernel_size,
  828. stride=1,
  829. padding=0,
  830. output_padding=0,
  831. groups=1,
  832. bias=True,
  833. dilation=1,
  834. padding_mode='zeros',
  835. device=None,
  836. dtype=None,
  837. **kwargs):
  838. FateTorchLayer.__init__(self)
  839. self.param_dict['stride'] = stride
  840. self.param_dict['padding'] = padding
  841. self.param_dict['output_padding'] = output_padding
  842. self.param_dict['groups'] = groups
  843. self.param_dict['bias'] = bias
  844. self.param_dict['dilation'] = dilation
  845. self.param_dict['padding_mode'] = padding_mode
  846. self.param_dict['device'] = device
  847. self.param_dict['dtype'] = dtype
  848. self.param_dict['out_channels'] = out_channels
  849. self.param_dict['kernel_size'] = kernel_size
  850. self.param_dict.update(kwargs)
  851. nn.modules.conv.LazyConvTranspose2d.__init__(self, **self.param_dict)
  852. class LazyConvTranspose3d(nn.modules.conv.LazyConvTranspose3d, FateTorchLayer):
  853. def __init__(
  854. self,
  855. out_channels,
  856. kernel_size,
  857. stride=1,
  858. padding=0,
  859. output_padding=0,
  860. groups=1,
  861. bias=True,
  862. dilation=1,
  863. padding_mode='zeros',
  864. device=None,
  865. dtype=None,
  866. **kwargs):
  867. FateTorchLayer.__init__(self)
  868. self.param_dict['stride'] = stride
  869. self.param_dict['padding'] = padding
  870. self.param_dict['output_padding'] = output_padding
  871. self.param_dict['groups'] = groups
  872. self.param_dict['bias'] = bias
  873. self.param_dict['dilation'] = dilation
  874. self.param_dict['padding_mode'] = padding_mode
  875. self.param_dict['device'] = device
  876. self.param_dict['dtype'] = dtype
  877. self.param_dict['out_channels'] = out_channels
  878. self.param_dict['kernel_size'] = kernel_size
  879. self.param_dict.update(kwargs)
  880. nn.modules.conv.LazyConvTranspose3d.__init__(self, **self.param_dict)
  881. class _ConvNd(nn.modules.conv._ConvNd, FateTorchLayer):
  882. def __init__(
  883. self,
  884. in_channels,
  885. out_channels,
  886. kernel_size,
  887. stride,
  888. padding,
  889. dilation,
  890. transposed,
  891. output_padding,
  892. groups,
  893. bias,
  894. padding_mode,
  895. device=None,
  896. dtype=None,
  897. **kwargs):
  898. FateTorchLayer.__init__(self)
  899. self.param_dict['device'] = device
  900. self.param_dict['dtype'] = dtype
  901. self.param_dict['in_channels'] = in_channels
  902. self.param_dict['out_channels'] = out_channels
  903. self.param_dict['kernel_size'] = kernel_size
  904. self.param_dict['stride'] = stride
  905. self.param_dict['padding'] = padding
  906. self.param_dict['dilation'] = dilation
  907. self.param_dict['transposed'] = transposed
  908. self.param_dict['output_padding'] = output_padding
  909. self.param_dict['groups'] = groups
  910. self.param_dict['bias'] = bias
  911. self.param_dict['padding_mode'] = padding_mode
  912. self.param_dict.update(kwargs)
  913. nn.modules.conv._ConvNd.__init__(self, **self.param_dict)
  914. class _ConvTransposeMixin(nn.modules.conv._ConvTransposeMixin, FateTorchLayer):
  915. def __init__(self, **kwargs):
  916. FateTorchLayer.__init__(self)
  917. self.param_dict.update(kwargs)
  918. nn.modules.conv._ConvTransposeMixin.__init__(self, **self.param_dict)
  919. class _ConvTransposeNd(nn.modules.conv._ConvTransposeNd, FateTorchLayer):
  920. def __init__(
  921. self,
  922. in_channels,
  923. out_channels,
  924. kernel_size,
  925. stride,
  926. padding,
  927. dilation,
  928. transposed,
  929. output_padding,
  930. groups,
  931. bias,
  932. padding_mode,
  933. device=None,
  934. dtype=None,
  935. **kwargs):
  936. FateTorchLayer.__init__(self)
  937. self.param_dict['device'] = device
  938. self.param_dict['dtype'] = dtype
  939. self.param_dict['in_channels'] = in_channels
  940. self.param_dict['out_channels'] = out_channels
  941. self.param_dict['kernel_size'] = kernel_size
  942. self.param_dict['stride'] = stride
  943. self.param_dict['padding'] = padding
  944. self.param_dict['dilation'] = dilation
  945. self.param_dict['transposed'] = transposed
  946. self.param_dict['output_padding'] = output_padding
  947. self.param_dict['groups'] = groups
  948. self.param_dict['bias'] = bias
  949. self.param_dict['padding_mode'] = padding_mode
  950. self.param_dict.update(kwargs)
  951. nn.modules.conv._ConvTransposeNd.__init__(self, **self.param_dict)
  952. class _LazyConvXdMixin(nn.modules.conv._LazyConvXdMixin, FateTorchLayer):
  953. def __init__(self, **kwargs):
  954. FateTorchLayer.__init__(self)
  955. self.param_dict.update(kwargs)
  956. nn.modules.conv._LazyConvXdMixin.__init__(self, **self.param_dict)
  957. class Transformer(nn.modules.transformer.Transformer, FateTorchLayer):
  958. def __init__(
  959. self,
  960. d_model=512,
  961. nhead=8,
  962. num_encoder_layers=6,
  963. num_decoder_layers=6,
  964. dim_feedforward=2048,
  965. dropout=0.1,
  966. custom_encoder=None,
  967. custom_decoder=None,
  968. layer_norm_eps=1e-05,
  969. batch_first=False,
  970. norm_first=False,
  971. device=None,
  972. dtype=None,
  973. **kwargs):
  974. FateTorchLayer.__init__(self)
  975. self.param_dict['d_model'] = d_model
  976. self.param_dict['nhead'] = nhead
  977. self.param_dict['num_encoder_layers'] = num_encoder_layers
  978. self.param_dict['num_decoder_layers'] = num_decoder_layers
  979. self.param_dict['dim_feedforward'] = dim_feedforward
  980. self.param_dict['dropout'] = dropout
  981. self.param_dict['custom_encoder'] = custom_encoder
  982. self.param_dict['custom_decoder'] = custom_decoder
  983. self.param_dict['layer_norm_eps'] = layer_norm_eps
  984. self.param_dict['batch_first'] = batch_first
  985. self.param_dict['norm_first'] = norm_first
  986. self.param_dict['device'] = device
  987. self.param_dict['dtype'] = dtype
  988. self.param_dict.update(kwargs)
  989. nn.modules.transformer.Transformer.__init__(self, **self.param_dict)
  990. class TransformerDecoder(
  991. nn.modules.transformer.TransformerDecoder,
  992. FateTorchLayer):
  993. def __init__(self, decoder_layer, num_layers, norm=None, **kwargs):
  994. FateTorchLayer.__init__(self)
  995. self.param_dict['norm'] = norm
  996. self.param_dict['decoder_layer'] = decoder_layer
  997. self.param_dict['num_layers'] = num_layers
  998. self.param_dict.update(kwargs)
  999. nn.modules.transformer.TransformerDecoder.__init__(
  1000. self, **self.param_dict)
  1001. class TransformerDecoderLayer(
  1002. nn.modules.transformer.TransformerDecoderLayer,
  1003. FateTorchLayer):
  1004. def __init__(
  1005. self,
  1006. d_model,
  1007. nhead,
  1008. dim_feedforward=2048,
  1009. dropout=0.1,
  1010. layer_norm_eps=1e-05,
  1011. batch_first=False,
  1012. norm_first=False,
  1013. device=None,
  1014. dtype=None,
  1015. **kwargs):
  1016. FateTorchLayer.__init__(self)
  1017. self.param_dict['dim_feedforward'] = dim_feedforward
  1018. self.param_dict['dropout'] = dropout
  1019. self.param_dict['layer_norm_eps'] = layer_norm_eps
  1020. self.param_dict['batch_first'] = batch_first
  1021. self.param_dict['norm_first'] = norm_first
  1022. self.param_dict['device'] = device
  1023. self.param_dict['dtype'] = dtype
  1024. self.param_dict['d_model'] = d_model
  1025. self.param_dict['nhead'] = nhead
  1026. self.param_dict.update(kwargs)
  1027. nn.modules.transformer.TransformerDecoderLayer.__init__(
  1028. self, **self.param_dict)
  1029. class TransformerEncoder(
  1030. nn.modules.transformer.TransformerEncoder,
  1031. FateTorchLayer):
  1032. def __init__(
  1033. self,
  1034. encoder_layer,
  1035. num_layers,
  1036. norm=None,
  1037. enable_nested_tensor=False,
  1038. **kwargs):
  1039. FateTorchLayer.__init__(self)
  1040. self.param_dict['norm'] = norm
  1041. self.param_dict['enable_nested_tensor'] = enable_nested_tensor
  1042. self.param_dict['encoder_layer'] = encoder_layer
  1043. self.param_dict['num_layers'] = num_layers
  1044. self.param_dict.update(kwargs)
  1045. nn.modules.transformer.TransformerEncoder.__init__(
  1046. self, **self.param_dict)
  1047. class TransformerEncoderLayer(
  1048. nn.modules.transformer.TransformerEncoderLayer,
  1049. FateTorchLayer):
  1050. def __init__(
  1051. self,
  1052. d_model,
  1053. nhead,
  1054. dim_feedforward=2048,
  1055. dropout=0.1,
  1056. layer_norm_eps=1e-05,
  1057. batch_first=False,
  1058. norm_first=False,
  1059. device=None,
  1060. dtype=None,
  1061. **kwargs):
  1062. FateTorchLayer.__init__(self)
  1063. self.param_dict['dim_feedforward'] = dim_feedforward
  1064. self.param_dict['dropout'] = dropout
  1065. self.param_dict['layer_norm_eps'] = layer_norm_eps
  1066. self.param_dict['batch_first'] = batch_first
  1067. self.param_dict['norm_first'] = norm_first
  1068. self.param_dict['device'] = device
  1069. self.param_dict['dtype'] = dtype
  1070. self.param_dict['d_model'] = d_model
  1071. self.param_dict['nhead'] = nhead
  1072. self.param_dict.update(kwargs)
  1073. nn.modules.transformer.TransformerEncoderLayer.__init__(
  1074. self, **self.param_dict)
  1075. class AdaptiveAvgPool1d(nn.modules.pooling.AdaptiveAvgPool1d, FateTorchLayer):
  1076. def __init__(self, output_size, **kwargs):
  1077. FateTorchLayer.__init__(self)
  1078. self.param_dict['output_size'] = output_size
  1079. self.param_dict.update(kwargs)
  1080. nn.modules.pooling.AdaptiveAvgPool1d.__init__(self, **self.param_dict)
  1081. class AdaptiveAvgPool2d(nn.modules.pooling.AdaptiveAvgPool2d, FateTorchLayer):
  1082. def __init__(self, output_size, **kwargs):
  1083. FateTorchLayer.__init__(self)
  1084. self.param_dict['output_size'] = output_size
  1085. self.param_dict.update(kwargs)
  1086. nn.modules.pooling.AdaptiveAvgPool2d.__init__(self, **self.param_dict)
  1087. class AdaptiveAvgPool3d(nn.modules.pooling.AdaptiveAvgPool3d, FateTorchLayer):
  1088. def __init__(self, output_size, **kwargs):
  1089. FateTorchLayer.__init__(self)
  1090. self.param_dict['output_size'] = output_size
  1091. self.param_dict.update(kwargs)
  1092. nn.modules.pooling.AdaptiveAvgPool3d.__init__(self, **self.param_dict)
  1093. class AdaptiveMaxPool1d(nn.modules.pooling.AdaptiveMaxPool1d, FateTorchLayer):
  1094. def __init__(self, output_size, return_indices=False, **kwargs):
  1095. FateTorchLayer.__init__(self)
  1096. self.param_dict['return_indices'] = return_indices
  1097. self.param_dict['output_size'] = output_size
  1098. self.param_dict.update(kwargs)
  1099. nn.modules.pooling.AdaptiveMaxPool1d.__init__(self, **self.param_dict)
  1100. class AdaptiveMaxPool2d(nn.modules.pooling.AdaptiveMaxPool2d, FateTorchLayer):
  1101. def __init__(self, output_size, return_indices=False, **kwargs):
  1102. FateTorchLayer.__init__(self)
  1103. self.param_dict['return_indices'] = return_indices
  1104. self.param_dict['output_size'] = output_size
  1105. self.param_dict.update(kwargs)
  1106. nn.modules.pooling.AdaptiveMaxPool2d.__init__(self, **self.param_dict)
  1107. class AdaptiveMaxPool3d(nn.modules.pooling.AdaptiveMaxPool3d, FateTorchLayer):
  1108. def __init__(self, output_size, return_indices=False, **kwargs):
  1109. FateTorchLayer.__init__(self)
  1110. self.param_dict['return_indices'] = return_indices
  1111. self.param_dict['output_size'] = output_size
  1112. self.param_dict.update(kwargs)
  1113. nn.modules.pooling.AdaptiveMaxPool3d.__init__(self, **self.param_dict)
  1114. class AvgPool1d(nn.modules.pooling.AvgPool1d, FateTorchLayer):
  1115. def __init__(
  1116. self,
  1117. kernel_size,
  1118. stride=None,
  1119. padding=0,
  1120. ceil_mode=False,
  1121. count_include_pad=True,
  1122. **kwargs):
  1123. FateTorchLayer.__init__(self)
  1124. self.param_dict['stride'] = stride
  1125. self.param_dict['padding'] = padding
  1126. self.param_dict['ceil_mode'] = ceil_mode
  1127. self.param_dict['count_include_pad'] = count_include_pad
  1128. self.param_dict['kernel_size'] = kernel_size
  1129. self.param_dict.update(kwargs)
  1130. nn.modules.pooling.AvgPool1d.__init__(self, **self.param_dict)
  1131. class AvgPool2d(nn.modules.pooling.AvgPool2d, FateTorchLayer):
  1132. def __init__(
  1133. self,
  1134. kernel_size,
  1135. stride=None,
  1136. padding=0,
  1137. ceil_mode=False,
  1138. count_include_pad=True,
  1139. divisor_override=None,
  1140. **kwargs):
  1141. FateTorchLayer.__init__(self)
  1142. self.param_dict['stride'] = stride
  1143. self.param_dict['padding'] = padding
  1144. self.param_dict['ceil_mode'] = ceil_mode
  1145. self.param_dict['count_include_pad'] = count_include_pad
  1146. self.param_dict['divisor_override'] = divisor_override
  1147. self.param_dict['kernel_size'] = kernel_size
  1148. self.param_dict.update(kwargs)
  1149. nn.modules.pooling.AvgPool2d.__init__(self, **self.param_dict)
  1150. class AvgPool3d(nn.modules.pooling.AvgPool3d, FateTorchLayer):
  1151. def __init__(
  1152. self,
  1153. kernel_size,
  1154. stride=None,
  1155. padding=0,
  1156. ceil_mode=False,
  1157. count_include_pad=True,
  1158. divisor_override=None,
  1159. **kwargs):
  1160. FateTorchLayer.__init__(self)
  1161. self.param_dict['stride'] = stride
  1162. self.param_dict['padding'] = padding
  1163. self.param_dict['ceil_mode'] = ceil_mode
  1164. self.param_dict['count_include_pad'] = count_include_pad
  1165. self.param_dict['divisor_override'] = divisor_override
  1166. self.param_dict['kernel_size'] = kernel_size
  1167. self.param_dict.update(kwargs)
  1168. nn.modules.pooling.AvgPool3d.__init__(self, **self.param_dict)
  1169. class FractionalMaxPool2d(
  1170. nn.modules.pooling.FractionalMaxPool2d,
  1171. FateTorchLayer):
  1172. def __init__(
  1173. self,
  1174. kernel_size,
  1175. output_size=None,
  1176. output_ratio=None,
  1177. return_indices=False,
  1178. _random_samples=None,
  1179. **kwargs):
  1180. FateTorchLayer.__init__(self)
  1181. self.param_dict['output_size'] = output_size
  1182. self.param_dict['output_ratio'] = output_ratio
  1183. self.param_dict['return_indices'] = return_indices
  1184. self.param_dict['_random_samples'] = _random_samples
  1185. self.param_dict['kernel_size'] = kernel_size
  1186. self.param_dict.update(kwargs)
  1187. nn.modules.pooling.FractionalMaxPool2d.__init__(
  1188. self, **self.param_dict)
  1189. class FractionalMaxPool3d(
  1190. nn.modules.pooling.FractionalMaxPool3d,
  1191. FateTorchLayer):
  1192. def __init__(
  1193. self,
  1194. kernel_size,
  1195. output_size=None,
  1196. output_ratio=None,
  1197. return_indices=False,
  1198. _random_samples=None,
  1199. **kwargs):
  1200. FateTorchLayer.__init__(self)
  1201. self.param_dict['output_size'] = output_size
  1202. self.param_dict['output_ratio'] = output_ratio
  1203. self.param_dict['return_indices'] = return_indices
  1204. self.param_dict['_random_samples'] = _random_samples
  1205. self.param_dict['kernel_size'] = kernel_size
  1206. self.param_dict.update(kwargs)
  1207. nn.modules.pooling.FractionalMaxPool3d.__init__(
  1208. self, **self.param_dict)
  1209. class LPPool1d(nn.modules.pooling.LPPool1d, FateTorchLayer):
  1210. def __init__(
  1211. self,
  1212. norm_type,
  1213. kernel_size,
  1214. stride=None,
  1215. ceil_mode=False,
  1216. **kwargs):
  1217. FateTorchLayer.__init__(self)
  1218. self.param_dict['stride'] = stride
  1219. self.param_dict['ceil_mode'] = ceil_mode
  1220. self.param_dict['norm_type'] = norm_type
  1221. self.param_dict['kernel_size'] = kernel_size
  1222. self.param_dict.update(kwargs)
  1223. nn.modules.pooling.LPPool1d.__init__(self, **self.param_dict)
  1224. class LPPool2d(nn.modules.pooling.LPPool2d, FateTorchLayer):
  1225. def __init__(
  1226. self,
  1227. norm_type,
  1228. kernel_size,
  1229. stride=None,
  1230. ceil_mode=False,
  1231. **kwargs):
  1232. FateTorchLayer.__init__(self)
  1233. self.param_dict['stride'] = stride
  1234. self.param_dict['ceil_mode'] = ceil_mode
  1235. self.param_dict['norm_type'] = norm_type
  1236. self.param_dict['kernel_size'] = kernel_size
  1237. self.param_dict.update(kwargs)
  1238. nn.modules.pooling.LPPool2d.__init__(self, **self.param_dict)
  1239. class MaxPool1d(nn.modules.pooling.MaxPool1d, FateTorchLayer):
  1240. def __init__(
  1241. self,
  1242. kernel_size,
  1243. stride=None,
  1244. padding=0,
  1245. dilation=1,
  1246. return_indices=False,
  1247. ceil_mode=False,
  1248. **kwargs):
  1249. FateTorchLayer.__init__(self)
  1250. self.param_dict['stride'] = stride
  1251. self.param_dict['padding'] = padding
  1252. self.param_dict['dilation'] = dilation
  1253. self.param_dict['return_indices'] = return_indices
  1254. self.param_dict['ceil_mode'] = ceil_mode
  1255. self.param_dict['kernel_size'] = kernel_size
  1256. self.param_dict.update(kwargs)
  1257. nn.modules.pooling.MaxPool1d.__init__(self, **self.param_dict)
  1258. class MaxPool2d(nn.modules.pooling.MaxPool2d, FateTorchLayer):
  1259. def __init__(
  1260. self,
  1261. kernel_size,
  1262. stride=None,
  1263. padding=0,
  1264. dilation=1,
  1265. return_indices=False,
  1266. ceil_mode=False,
  1267. **kwargs):
  1268. FateTorchLayer.__init__(self)
  1269. self.param_dict['stride'] = stride
  1270. self.param_dict['padding'] = padding
  1271. self.param_dict['dilation'] = dilation
  1272. self.param_dict['return_indices'] = return_indices
  1273. self.param_dict['ceil_mode'] = ceil_mode
  1274. self.param_dict['kernel_size'] = kernel_size
  1275. self.param_dict.update(kwargs)
  1276. nn.modules.pooling.MaxPool2d.__init__(self, **self.param_dict)
  1277. class MaxPool3d(nn.modules.pooling.MaxPool3d, FateTorchLayer):
  1278. def __init__(
  1279. self,
  1280. kernel_size,
  1281. stride=None,
  1282. padding=0,
  1283. dilation=1,
  1284. return_indices=False,
  1285. ceil_mode=False,
  1286. **kwargs):
  1287. FateTorchLayer.__init__(self)
  1288. self.param_dict['stride'] = stride
  1289. self.param_dict['padding'] = padding
  1290. self.param_dict['dilation'] = dilation
  1291. self.param_dict['return_indices'] = return_indices
  1292. self.param_dict['ceil_mode'] = ceil_mode
  1293. self.param_dict['kernel_size'] = kernel_size
  1294. self.param_dict.update(kwargs)
  1295. nn.modules.pooling.MaxPool3d.__init__(self, **self.param_dict)
  1296. class MaxUnpool1d(nn.modules.pooling.MaxUnpool1d, FateTorchLayer):
  1297. def __init__(self, kernel_size, stride=None, padding=0, **kwargs):
  1298. FateTorchLayer.__init__(self)
  1299. self.param_dict['stride'] = stride
  1300. self.param_dict['padding'] = padding
  1301. self.param_dict['kernel_size'] = kernel_size
  1302. self.param_dict.update(kwargs)
  1303. nn.modules.pooling.MaxUnpool1d.__init__(self, **self.param_dict)
  1304. class MaxUnpool2d(nn.modules.pooling.MaxUnpool2d, FateTorchLayer):
  1305. def __init__(self, kernel_size, stride=None, padding=0, **kwargs):
  1306. FateTorchLayer.__init__(self)
  1307. self.param_dict['stride'] = stride
  1308. self.param_dict['padding'] = padding
  1309. self.param_dict['kernel_size'] = kernel_size
  1310. self.param_dict.update(kwargs)
  1311. nn.modules.pooling.MaxUnpool2d.__init__(self, **self.param_dict)
  1312. class MaxUnpool3d(nn.modules.pooling.MaxUnpool3d, FateTorchLayer):
  1313. def __init__(self, kernel_size, stride=None, padding=0, **kwargs):
  1314. FateTorchLayer.__init__(self)
  1315. self.param_dict['stride'] = stride
  1316. self.param_dict['padding'] = padding
  1317. self.param_dict['kernel_size'] = kernel_size
  1318. self.param_dict.update(kwargs)
  1319. nn.modules.pooling.MaxUnpool3d.__init__(self, **self.param_dict)
  1320. class _AdaptiveAvgPoolNd(
  1321. nn.modules.pooling._AdaptiveAvgPoolNd,
  1322. FateTorchLayer):
  1323. def __init__(self, output_size, **kwargs):
  1324. FateTorchLayer.__init__(self)
  1325. self.param_dict['output_size'] = output_size
  1326. self.param_dict.update(kwargs)
  1327. nn.modules.pooling._AdaptiveAvgPoolNd.__init__(self, **self.param_dict)
  1328. class _AdaptiveMaxPoolNd(
  1329. nn.modules.pooling._AdaptiveMaxPoolNd,
  1330. FateTorchLayer):
  1331. def __init__(self, output_size, return_indices=False, **kwargs):
  1332. FateTorchLayer.__init__(self)
  1333. self.param_dict['return_indices'] = return_indices
  1334. self.param_dict['output_size'] = output_size
  1335. self.param_dict.update(kwargs)
  1336. nn.modules.pooling._AdaptiveMaxPoolNd.__init__(self, **self.param_dict)
  1337. class _AvgPoolNd(nn.modules.pooling._AvgPoolNd, FateTorchLayer):
  1338. def __init__(self, **kwargs):
  1339. FateTorchLayer.__init__(self)
  1340. self.param_dict.update(kwargs)
  1341. nn.modules.pooling._AvgPoolNd.__init__(self, **self.param_dict)
  1342. class _LPPoolNd(nn.modules.pooling._LPPoolNd, FateTorchLayer):
  1343. def __init__(
  1344. self,
  1345. norm_type,
  1346. kernel_size,
  1347. stride=None,
  1348. ceil_mode=False,
  1349. **kwargs):
  1350. FateTorchLayer.__init__(self)
  1351. self.param_dict['stride'] = stride
  1352. self.param_dict['ceil_mode'] = ceil_mode
  1353. self.param_dict['norm_type'] = norm_type
  1354. self.param_dict['kernel_size'] = kernel_size
  1355. self.param_dict.update(kwargs)
  1356. nn.modules.pooling._LPPoolNd.__init__(self, **self.param_dict)
  1357. class _MaxPoolNd(nn.modules.pooling._MaxPoolNd, FateTorchLayer):
  1358. def __init__(
  1359. self,
  1360. kernel_size,
  1361. stride=None,
  1362. padding=0,
  1363. dilation=1,
  1364. return_indices=False,
  1365. ceil_mode=False,
  1366. **kwargs):
  1367. FateTorchLayer.__init__(self)
  1368. self.param_dict['stride'] = stride
  1369. self.param_dict['padding'] = padding
  1370. self.param_dict['dilation'] = dilation
  1371. self.param_dict['return_indices'] = return_indices
  1372. self.param_dict['ceil_mode'] = ceil_mode
  1373. self.param_dict['kernel_size'] = kernel_size
  1374. self.param_dict.update(kwargs)
  1375. nn.modules.pooling._MaxPoolNd.__init__(self, **self.param_dict)
  1376. class _MaxUnpoolNd(nn.modules.pooling._MaxUnpoolNd, FateTorchLayer):
  1377. def __init__(self, **kwargs):
  1378. FateTorchLayer.__init__(self)
  1379. self.param_dict.update(kwargs)
  1380. nn.modules.pooling._MaxUnpoolNd.__init__(self, **self.param_dict)
  1381. class BatchNorm1d(nn.modules.batchnorm.BatchNorm1d, FateTorchLayer):
  1382. def __init__(
  1383. self,
  1384. num_features,
  1385. eps=1e-05,
  1386. momentum=0.1,
  1387. affine=True,
  1388. track_running_stats=True,
  1389. device=None,
  1390. dtype=None,
  1391. **kwargs):
  1392. FateTorchLayer.__init__(self)
  1393. self.param_dict['eps'] = eps
  1394. self.param_dict['momentum'] = momentum
  1395. self.param_dict['affine'] = affine
  1396. self.param_dict['track_running_stats'] = track_running_stats
  1397. self.param_dict['device'] = device
  1398. self.param_dict['dtype'] = dtype
  1399. self.param_dict['num_features'] = num_features
  1400. self.param_dict.update(kwargs)
  1401. nn.modules.batchnorm.BatchNorm1d.__init__(self, **self.param_dict)
  1402. class BatchNorm2d(nn.modules.batchnorm.BatchNorm2d, FateTorchLayer):
  1403. def __init__(
  1404. self,
  1405. num_features,
  1406. eps=1e-05,
  1407. momentum=0.1,
  1408. affine=True,
  1409. track_running_stats=True,
  1410. device=None,
  1411. dtype=None,
  1412. **kwargs):
  1413. FateTorchLayer.__init__(self)
  1414. self.param_dict['eps'] = eps
  1415. self.param_dict['momentum'] = momentum
  1416. self.param_dict['affine'] = affine
  1417. self.param_dict['track_running_stats'] = track_running_stats
  1418. self.param_dict['device'] = device
  1419. self.param_dict['dtype'] = dtype
  1420. self.param_dict['num_features'] = num_features
  1421. self.param_dict.update(kwargs)
  1422. nn.modules.batchnorm.BatchNorm2d.__init__(self, **self.param_dict)
  1423. class BatchNorm3d(nn.modules.batchnorm.BatchNorm3d, FateTorchLayer):
  1424. def __init__(
  1425. self,
  1426. num_features,
  1427. eps=1e-05,
  1428. momentum=0.1,
  1429. affine=True,
  1430. track_running_stats=True,
  1431. device=None,
  1432. dtype=None,
  1433. **kwargs):
  1434. FateTorchLayer.__init__(self)
  1435. self.param_dict['eps'] = eps
  1436. self.param_dict['momentum'] = momentum
  1437. self.param_dict['affine'] = affine
  1438. self.param_dict['track_running_stats'] = track_running_stats
  1439. self.param_dict['device'] = device
  1440. self.param_dict['dtype'] = dtype
  1441. self.param_dict['num_features'] = num_features
  1442. self.param_dict.update(kwargs)
  1443. nn.modules.batchnorm.BatchNorm3d.__init__(self, **self.param_dict)
  1444. class LazyBatchNorm1d(nn.modules.batchnorm.LazyBatchNorm1d, FateTorchLayer):
  1445. def __init__(
  1446. self,
  1447. eps=1e-05,
  1448. momentum=0.1,
  1449. affine=True,
  1450. track_running_stats=True,
  1451. device=None,
  1452. dtype=None,
  1453. **kwargs):
  1454. FateTorchLayer.__init__(self)
  1455. self.param_dict['eps'] = eps
  1456. self.param_dict['momentum'] = momentum
  1457. self.param_dict['affine'] = affine
  1458. self.param_dict['track_running_stats'] = track_running_stats
  1459. self.param_dict['device'] = device
  1460. self.param_dict['dtype'] = dtype
  1461. self.param_dict.update(kwargs)
  1462. nn.modules.batchnorm.LazyBatchNorm1d.__init__(self, **self.param_dict)
  1463. class LazyBatchNorm2d(nn.modules.batchnorm.LazyBatchNorm2d, FateTorchLayer):
  1464. def __init__(
  1465. self,
  1466. eps=1e-05,
  1467. momentum=0.1,
  1468. affine=True,
  1469. track_running_stats=True,
  1470. device=None,
  1471. dtype=None,
  1472. **kwargs):
  1473. FateTorchLayer.__init__(self)
  1474. self.param_dict['eps'] = eps
  1475. self.param_dict['momentum'] = momentum
  1476. self.param_dict['affine'] = affine
  1477. self.param_dict['track_running_stats'] = track_running_stats
  1478. self.param_dict['device'] = device
  1479. self.param_dict['dtype'] = dtype
  1480. self.param_dict.update(kwargs)
  1481. nn.modules.batchnorm.LazyBatchNorm2d.__init__(self, **self.param_dict)
  1482. class LazyBatchNorm3d(nn.modules.batchnorm.LazyBatchNorm3d, FateTorchLayer):
  1483. def __init__(
  1484. self,
  1485. eps=1e-05,
  1486. momentum=0.1,
  1487. affine=True,
  1488. track_running_stats=True,
  1489. device=None,
  1490. dtype=None,
  1491. **kwargs):
  1492. FateTorchLayer.__init__(self)
  1493. self.param_dict['eps'] = eps
  1494. self.param_dict['momentum'] = momentum
  1495. self.param_dict['affine'] = affine
  1496. self.param_dict['track_running_stats'] = track_running_stats
  1497. self.param_dict['device'] = device
  1498. self.param_dict['dtype'] = dtype
  1499. self.param_dict.update(kwargs)
  1500. nn.modules.batchnorm.LazyBatchNorm3d.__init__(self, **self.param_dict)
  1501. class SyncBatchNorm(nn.modules.batchnorm.SyncBatchNorm, FateTorchLayer):
  1502. def __init__(
  1503. self,
  1504. num_features,
  1505. eps=1e-05,
  1506. momentum=0.1,
  1507. affine=True,
  1508. track_running_stats=True,
  1509. process_group=None,
  1510. device=None,
  1511. dtype=None,
  1512. **kwargs):
  1513. FateTorchLayer.__init__(self)
  1514. self.param_dict['eps'] = eps
  1515. self.param_dict['momentum'] = momentum
  1516. self.param_dict['affine'] = affine
  1517. self.param_dict['track_running_stats'] = track_running_stats
  1518. self.param_dict['process_group'] = process_group
  1519. self.param_dict['device'] = device
  1520. self.param_dict['dtype'] = dtype
  1521. self.param_dict['num_features'] = num_features
  1522. self.param_dict.update(kwargs)
  1523. nn.modules.batchnorm.SyncBatchNorm.__init__(self, **self.param_dict)
  1524. class _BatchNorm(nn.modules.batchnorm._BatchNorm, FateTorchLayer):
  1525. def __init__(
  1526. self,
  1527. num_features,
  1528. eps=1e-05,
  1529. momentum=0.1,
  1530. affine=True,
  1531. track_running_stats=True,
  1532. device=None,
  1533. dtype=None,
  1534. **kwargs):
  1535. FateTorchLayer.__init__(self)
  1536. self.param_dict['eps'] = eps
  1537. self.param_dict['momentum'] = momentum
  1538. self.param_dict['affine'] = affine
  1539. self.param_dict['track_running_stats'] = track_running_stats
  1540. self.param_dict['device'] = device
  1541. self.param_dict['dtype'] = dtype
  1542. self.param_dict['num_features'] = num_features
  1543. self.param_dict.update(kwargs)
  1544. nn.modules.batchnorm._BatchNorm.__init__(self, **self.param_dict)
  1545. class _LazyNormBase(nn.modules.batchnorm._LazyNormBase, FateTorchLayer):
  1546. def __init__(
  1547. self,
  1548. eps=1e-05,
  1549. momentum=0.1,
  1550. affine=True,
  1551. track_running_stats=True,
  1552. device=None,
  1553. dtype=None,
  1554. **kwargs):
  1555. FateTorchLayer.__init__(self)
  1556. self.param_dict['eps'] = eps
  1557. self.param_dict['momentum'] = momentum
  1558. self.param_dict['affine'] = affine
  1559. self.param_dict['track_running_stats'] = track_running_stats
  1560. self.param_dict['device'] = device
  1561. self.param_dict['dtype'] = dtype
  1562. self.param_dict.update(kwargs)
  1563. nn.modules.batchnorm._LazyNormBase.__init__(self, **self.param_dict)
  1564. class _NormBase(nn.modules.batchnorm._NormBase, FateTorchLayer):
  1565. def __init__(
  1566. self,
  1567. num_features,
  1568. eps=1e-05,
  1569. momentum=0.1,
  1570. affine=True,
  1571. track_running_stats=True,
  1572. device=None,
  1573. dtype=None,
  1574. **kwargs):
  1575. FateTorchLayer.__init__(self)
  1576. self.param_dict['eps'] = eps
  1577. self.param_dict['momentum'] = momentum
  1578. self.param_dict['affine'] = affine
  1579. self.param_dict['track_running_stats'] = track_running_stats
  1580. self.param_dict['device'] = device
  1581. self.param_dict['dtype'] = dtype
  1582. self.param_dict['num_features'] = num_features
  1583. self.param_dict.update(kwargs)
  1584. nn.modules.batchnorm._NormBase.__init__(self, **self.param_dict)
  1585. class ConstantPad1d(nn.modules.padding.ConstantPad1d, FateTorchLayer):
  1586. def __init__(self, padding, value, **kwargs):
  1587. FateTorchLayer.__init__(self)
  1588. self.param_dict['padding'] = padding
  1589. self.param_dict['value'] = value
  1590. self.param_dict.update(kwargs)
  1591. nn.modules.padding.ConstantPad1d.__init__(self, **self.param_dict)
  1592. class ConstantPad2d(nn.modules.padding.ConstantPad2d, FateTorchLayer):
  1593. def __init__(self, padding, value, **kwargs):
  1594. FateTorchLayer.__init__(self)
  1595. self.param_dict['padding'] = padding
  1596. self.param_dict['value'] = value
  1597. self.param_dict.update(kwargs)
  1598. nn.modules.padding.ConstantPad2d.__init__(self, **self.param_dict)
  1599. class ConstantPad3d(nn.modules.padding.ConstantPad3d, FateTorchLayer):
  1600. def __init__(self, padding, value, **kwargs):
  1601. FateTorchLayer.__init__(self)
  1602. self.param_dict['padding'] = padding
  1603. self.param_dict['value'] = value
  1604. self.param_dict.update(kwargs)
  1605. nn.modules.padding.ConstantPad3d.__init__(self, **self.param_dict)
  1606. class ReflectionPad1d(nn.modules.padding.ReflectionPad1d, FateTorchLayer):
  1607. def __init__(self, padding, **kwargs):
  1608. FateTorchLayer.__init__(self)
  1609. self.param_dict['padding'] = padding
  1610. self.param_dict.update(kwargs)
  1611. nn.modules.padding.ReflectionPad1d.__init__(self, **self.param_dict)
  1612. class ReflectionPad2d(nn.modules.padding.ReflectionPad2d, FateTorchLayer):
  1613. def __init__(self, padding, **kwargs):
  1614. FateTorchLayer.__init__(self)
  1615. self.param_dict['padding'] = padding
  1616. self.param_dict.update(kwargs)
  1617. nn.modules.padding.ReflectionPad2d.__init__(self, **self.param_dict)
  1618. class ReflectionPad3d(nn.modules.padding.ReflectionPad3d, FateTorchLayer):
  1619. def __init__(self, padding, **kwargs):
  1620. FateTorchLayer.__init__(self)
  1621. self.param_dict['padding'] = padding
  1622. self.param_dict.update(kwargs)
  1623. nn.modules.padding.ReflectionPad3d.__init__(self, **self.param_dict)
  1624. class ReplicationPad1d(nn.modules.padding.ReplicationPad1d, FateTorchLayer):
  1625. def __init__(self, padding, **kwargs):
  1626. FateTorchLayer.__init__(self)
  1627. self.param_dict['padding'] = padding
  1628. self.param_dict.update(kwargs)
  1629. nn.modules.padding.ReplicationPad1d.__init__(self, **self.param_dict)
  1630. class ReplicationPad2d(nn.modules.padding.ReplicationPad2d, FateTorchLayer):
  1631. def __init__(self, padding, **kwargs):
  1632. FateTorchLayer.__init__(self)
  1633. self.param_dict['padding'] = padding
  1634. self.param_dict.update(kwargs)
  1635. nn.modules.padding.ReplicationPad2d.__init__(self, **self.param_dict)
  1636. class ReplicationPad3d(nn.modules.padding.ReplicationPad3d, FateTorchLayer):
  1637. def __init__(self, padding, **kwargs):
  1638. FateTorchLayer.__init__(self)
  1639. self.param_dict['padding'] = padding
  1640. self.param_dict.update(kwargs)
  1641. nn.modules.padding.ReplicationPad3d.__init__(self, **self.param_dict)
  1642. class ZeroPad2d(nn.modules.padding.ZeroPad2d, FateTorchLayer):
  1643. def __init__(self, padding, **kwargs):
  1644. FateTorchLayer.__init__(self)
  1645. self.param_dict['padding'] = padding
  1646. self.param_dict.update(kwargs)
  1647. nn.modules.padding.ZeroPad2d.__init__(self, **self.param_dict)
  1648. class _ConstantPadNd(nn.modules.padding._ConstantPadNd, FateTorchLayer):
  1649. def __init__(self, value, **kwargs):
  1650. FateTorchLayer.__init__(self)
  1651. self.param_dict['value'] = value
  1652. self.param_dict.update(kwargs)
  1653. nn.modules.padding._ConstantPadNd.__init__(self, **self.param_dict)
  1654. class _ReflectionPadNd(nn.modules.padding._ReflectionPadNd, FateTorchLayer):
  1655. def __init__(self, **kwargs):
  1656. FateTorchLayer.__init__(self)
  1657. self.param_dict.update(kwargs)
  1658. nn.modules.padding._ReflectionPadNd.__init__(self, **self.param_dict)
  1659. class _ReplicationPadNd(nn.modules.padding._ReplicationPadNd, FateTorchLayer):
  1660. def __init__(self, **kwargs):
  1661. FateTorchLayer.__init__(self)
  1662. self.param_dict.update(kwargs)
  1663. nn.modules.padding._ReplicationPadNd.__init__(self, **self.param_dict)
  1664. class BCELoss(nn.modules.loss.BCELoss, FateTorchLoss):
  1665. def __init__(
  1666. self,
  1667. weight=None,
  1668. size_average=None,
  1669. reduce=None,
  1670. reduction='mean',
  1671. **kwargs):
  1672. FateTorchLoss.__init__(self)
  1673. self.param_dict['weight'] = weight
  1674. self.param_dict['size_average'] = size_average
  1675. self.param_dict['reduce'] = reduce
  1676. self.param_dict['reduction'] = reduction
  1677. self.param_dict.update(kwargs)
  1678. nn.modules.loss.BCELoss.__init__(self, **self.param_dict)
  1679. class BCEWithLogitsLoss(nn.modules.loss.BCEWithLogitsLoss, FateTorchLoss):
  1680. def __init__(
  1681. self,
  1682. weight=None,
  1683. size_average=None,
  1684. reduce=None,
  1685. reduction='mean',
  1686. pos_weight=None,
  1687. **kwargs):
  1688. FateTorchLoss.__init__(self)
  1689. self.param_dict['weight'] = weight
  1690. self.param_dict['size_average'] = size_average
  1691. self.param_dict['reduce'] = reduce
  1692. self.param_dict['reduction'] = reduction
  1693. self.param_dict['pos_weight'] = pos_weight
  1694. self.param_dict.update(kwargs)
  1695. nn.modules.loss.BCEWithLogitsLoss.__init__(self, **self.param_dict)
  1696. class CTCLoss(nn.modules.loss.CTCLoss, FateTorchLoss):
  1697. def __init__(
  1698. self,
  1699. blank=0,
  1700. reduction='mean',
  1701. zero_infinity=False,
  1702. **kwargs):
  1703. FateTorchLoss.__init__(self)
  1704. self.param_dict['blank'] = blank
  1705. self.param_dict['reduction'] = reduction
  1706. self.param_dict['zero_infinity'] = zero_infinity
  1707. self.param_dict.update(kwargs)
  1708. nn.modules.loss.CTCLoss.__init__(self, **self.param_dict)
  1709. class CosineEmbeddingLoss(nn.modules.loss.CosineEmbeddingLoss, FateTorchLoss):
  1710. def __init__(
  1711. self,
  1712. margin=0.0,
  1713. size_average=None,
  1714. reduce=None,
  1715. reduction='mean',
  1716. **kwargs):
  1717. FateTorchLoss.__init__(self)
  1718. self.param_dict['margin'] = margin
  1719. self.param_dict['size_average'] = size_average
  1720. self.param_dict['reduce'] = reduce
  1721. self.param_dict['reduction'] = reduction
  1722. self.param_dict.update(kwargs)
  1723. nn.modules.loss.CosineEmbeddingLoss.__init__(self, **self.param_dict)
  1724. class CrossEntropyLoss(nn.modules.loss.CrossEntropyLoss, FateTorchLoss):
  1725. def __init__(
  1726. self,
  1727. weight=None,
  1728. size_average=None,
  1729. ignore_index=-100,
  1730. reduce=None,
  1731. reduction='mean',
  1732. label_smoothing=0.0,
  1733. **kwargs):
  1734. FateTorchLoss.__init__(self)
  1735. self.param_dict['weight'] = weight
  1736. self.param_dict['size_average'] = size_average
  1737. self.param_dict['ignore_index'] = ignore_index
  1738. self.param_dict['reduce'] = reduce
  1739. self.param_dict['reduction'] = reduction
  1740. self.param_dict['label_smoothing'] = label_smoothing
  1741. self.param_dict.update(kwargs)
  1742. nn.modules.loss.CrossEntropyLoss.__init__(self, **self.param_dict)
  1743. class GaussianNLLLoss(nn.modules.loss.GaussianNLLLoss, FateTorchLoss):
  1744. def __init__(self, **kwargs):
  1745. FateTorchLoss.__init__(self)
  1746. self.param_dict.update(kwargs)
  1747. nn.modules.loss.GaussianNLLLoss.__init__(self, **self.param_dict)
  1748. class HingeEmbeddingLoss(nn.modules.loss.HingeEmbeddingLoss, FateTorchLoss):
  1749. def __init__(
  1750. self,
  1751. margin=1.0,
  1752. size_average=None,
  1753. reduce=None,
  1754. reduction='mean',
  1755. **kwargs):
  1756. FateTorchLoss.__init__(self)
  1757. self.param_dict['margin'] = margin
  1758. self.param_dict['size_average'] = size_average
  1759. self.param_dict['reduce'] = reduce
  1760. self.param_dict['reduction'] = reduction
  1761. self.param_dict.update(kwargs)
  1762. nn.modules.loss.HingeEmbeddingLoss.__init__(self, **self.param_dict)
  1763. class HuberLoss(nn.modules.loss.HuberLoss, FateTorchLoss):
  1764. def __init__(self, reduction='mean', delta=1.0, **kwargs):
  1765. FateTorchLoss.__init__(self)
  1766. self.param_dict['reduction'] = reduction
  1767. self.param_dict['delta'] = delta
  1768. self.param_dict.update(kwargs)
  1769. nn.modules.loss.HuberLoss.__init__(self, **self.param_dict)
  1770. class KLDivLoss(nn.modules.loss.KLDivLoss, FateTorchLoss):
  1771. def __init__(
  1772. self,
  1773. size_average=None,
  1774. reduce=None,
  1775. reduction='mean',
  1776. log_target=False,
  1777. **kwargs):
  1778. FateTorchLoss.__init__(self)
  1779. self.param_dict['size_average'] = size_average
  1780. self.param_dict['reduce'] = reduce
  1781. self.param_dict['reduction'] = reduction
  1782. self.param_dict['log_target'] = log_target
  1783. self.param_dict.update(kwargs)
  1784. nn.modules.loss.KLDivLoss.__init__(self, **self.param_dict)
  1785. class L1Loss(nn.modules.loss.L1Loss, FateTorchLoss):
  1786. def __init__(
  1787. self,
  1788. size_average=None,
  1789. reduce=None,
  1790. reduction='mean',
  1791. **kwargs):
  1792. FateTorchLoss.__init__(self)
  1793. self.param_dict['size_average'] = size_average
  1794. self.param_dict['reduce'] = reduce
  1795. self.param_dict['reduction'] = reduction
  1796. self.param_dict.update(kwargs)
  1797. nn.modules.loss.L1Loss.__init__(self, **self.param_dict)
  1798. class MSELoss(nn.modules.loss.MSELoss, FateTorchLoss):
  1799. def __init__(
  1800. self,
  1801. size_average=None,
  1802. reduce=None,
  1803. reduction='mean',
  1804. **kwargs):
  1805. FateTorchLoss.__init__(self)
  1806. self.param_dict['size_average'] = size_average
  1807. self.param_dict['reduce'] = reduce
  1808. self.param_dict['reduction'] = reduction
  1809. self.param_dict.update(kwargs)
  1810. nn.modules.loss.MSELoss.__init__(self, **self.param_dict)
  1811. class MarginRankingLoss(nn.modules.loss.MarginRankingLoss, FateTorchLoss):
  1812. def __init__(
  1813. self,
  1814. margin=0.0,
  1815. size_average=None,
  1816. reduce=None,
  1817. reduction='mean',
  1818. **kwargs):
  1819. FateTorchLoss.__init__(self)
  1820. self.param_dict['margin'] = margin
  1821. self.param_dict['size_average'] = size_average
  1822. self.param_dict['reduce'] = reduce
  1823. self.param_dict['reduction'] = reduction
  1824. self.param_dict.update(kwargs)
  1825. nn.modules.loss.MarginRankingLoss.__init__(self, **self.param_dict)
  1826. class MultiLabelMarginLoss(
  1827. nn.modules.loss.MultiLabelMarginLoss,
  1828. FateTorchLoss):
  1829. def __init__(
  1830. self,
  1831. size_average=None,
  1832. reduce=None,
  1833. reduction='mean',
  1834. **kwargs):
  1835. FateTorchLoss.__init__(self)
  1836. self.param_dict['size_average'] = size_average
  1837. self.param_dict['reduce'] = reduce
  1838. self.param_dict['reduction'] = reduction
  1839. self.param_dict.update(kwargs)
  1840. nn.modules.loss.MultiLabelMarginLoss.__init__(self, **self.param_dict)
  1841. class MultiLabelSoftMarginLoss(
  1842. nn.modules.loss.MultiLabelSoftMarginLoss,
  1843. FateTorchLoss):
  1844. def __init__(
  1845. self,
  1846. weight=None,
  1847. size_average=None,
  1848. reduce=None,
  1849. reduction='mean',
  1850. **kwargs):
  1851. FateTorchLoss.__init__(self)
  1852. self.param_dict['weight'] = weight
  1853. self.param_dict['size_average'] = size_average
  1854. self.param_dict['reduce'] = reduce
  1855. self.param_dict['reduction'] = reduction
  1856. self.param_dict.update(kwargs)
  1857. nn.modules.loss.MultiLabelSoftMarginLoss.__init__(
  1858. self, **self.param_dict)
  1859. class MultiMarginLoss(nn.modules.loss.MultiMarginLoss, FateTorchLoss):
  1860. def __init__(
  1861. self,
  1862. p=1,
  1863. margin=1.0,
  1864. weight=None,
  1865. size_average=None,
  1866. reduce=None,
  1867. reduction='mean',
  1868. **kwargs):
  1869. FateTorchLoss.__init__(self)
  1870. self.param_dict['p'] = p
  1871. self.param_dict['margin'] = margin
  1872. self.param_dict['weight'] = weight
  1873. self.param_dict['size_average'] = size_average
  1874. self.param_dict['reduce'] = reduce
  1875. self.param_dict['reduction'] = reduction
  1876. self.param_dict.update(kwargs)
  1877. nn.modules.loss.MultiMarginLoss.__init__(self, **self.param_dict)
  1878. class NLLLoss(nn.modules.loss.NLLLoss, FateTorchLoss):
  1879. def __init__(
  1880. self,
  1881. weight=None,
  1882. size_average=None,
  1883. ignore_index=-100,
  1884. reduce=None,
  1885. reduction='mean',
  1886. **kwargs):
  1887. FateTorchLoss.__init__(self)
  1888. self.param_dict['weight'] = weight
  1889. self.param_dict['size_average'] = size_average
  1890. self.param_dict['ignore_index'] = ignore_index
  1891. self.param_dict['reduce'] = reduce
  1892. self.param_dict['reduction'] = reduction
  1893. self.param_dict.update(kwargs)
  1894. nn.modules.loss.NLLLoss.__init__(self, **self.param_dict)
  1895. class NLLLoss2d(nn.modules.loss.NLLLoss2d, FateTorchLoss):
  1896. def __init__(
  1897. self,
  1898. weight=None,
  1899. size_average=None,
  1900. ignore_index=-100,
  1901. reduce=None,
  1902. reduction='mean',
  1903. **kwargs):
  1904. FateTorchLoss.__init__(self)
  1905. self.param_dict['weight'] = weight
  1906. self.param_dict['size_average'] = size_average
  1907. self.param_dict['ignore_index'] = ignore_index
  1908. self.param_dict['reduce'] = reduce
  1909. self.param_dict['reduction'] = reduction
  1910. self.param_dict.update(kwargs)
  1911. nn.modules.loss.NLLLoss2d.__init__(self, **self.param_dict)
  1912. class PoissonNLLLoss(nn.modules.loss.PoissonNLLLoss, FateTorchLoss):
  1913. def __init__(
  1914. self,
  1915. log_input=True,
  1916. full=False,
  1917. size_average=None,
  1918. eps=1e-08,
  1919. reduce=None,
  1920. reduction='mean',
  1921. **kwargs):
  1922. FateTorchLoss.__init__(self)
  1923. self.param_dict['log_input'] = log_input
  1924. self.param_dict['full'] = full
  1925. self.param_dict['size_average'] = size_average
  1926. self.param_dict['eps'] = eps
  1927. self.param_dict['reduce'] = reduce
  1928. self.param_dict['reduction'] = reduction
  1929. self.param_dict.update(kwargs)
  1930. nn.modules.loss.PoissonNLLLoss.__init__(self, **self.param_dict)
  1931. class SmoothL1Loss(nn.modules.loss.SmoothL1Loss, FateTorchLoss):
  1932. def __init__(
  1933. self,
  1934. size_average=None,
  1935. reduce=None,
  1936. reduction='mean',
  1937. beta=1.0,
  1938. **kwargs):
  1939. FateTorchLoss.__init__(self)
  1940. self.param_dict['size_average'] = size_average
  1941. self.param_dict['reduce'] = reduce
  1942. self.param_dict['reduction'] = reduction
  1943. self.param_dict['beta'] = beta
  1944. self.param_dict.update(kwargs)
  1945. nn.modules.loss.SmoothL1Loss.__init__(self, **self.param_dict)
  1946. class SoftMarginLoss(nn.modules.loss.SoftMarginLoss, FateTorchLoss):
  1947. def __init__(
  1948. self,
  1949. size_average=None,
  1950. reduce=None,
  1951. reduction='mean',
  1952. **kwargs):
  1953. FateTorchLoss.__init__(self)
  1954. self.param_dict['size_average'] = size_average
  1955. self.param_dict['reduce'] = reduce
  1956. self.param_dict['reduction'] = reduction
  1957. self.param_dict.update(kwargs)
  1958. nn.modules.loss.SoftMarginLoss.__init__(self, **self.param_dict)
  1959. class TripletMarginLoss(nn.modules.loss.TripletMarginLoss, FateTorchLoss):
  1960. def __init__(
  1961. self,
  1962. margin=1.0,
  1963. p=2.0,
  1964. eps=1e-06,
  1965. swap=False,
  1966. size_average=None,
  1967. reduce=None,
  1968. reduction='mean',
  1969. **kwargs):
  1970. FateTorchLoss.__init__(self)
  1971. self.param_dict['margin'] = margin
  1972. self.param_dict['p'] = p
  1973. self.param_dict['eps'] = eps
  1974. self.param_dict['swap'] = swap
  1975. self.param_dict['size_average'] = size_average
  1976. self.param_dict['reduce'] = reduce
  1977. self.param_dict['reduction'] = reduction
  1978. self.param_dict.update(kwargs)
  1979. nn.modules.loss.TripletMarginLoss.__init__(self, **self.param_dict)
  1980. class TripletMarginWithDistanceLoss(
  1981. nn.modules.loss.TripletMarginWithDistanceLoss,
  1982. FateTorchLoss):
  1983. def __init__(self, **kwargs):
  1984. FateTorchLoss.__init__(self)
  1985. self.param_dict.update(kwargs)
  1986. nn.modules.loss.TripletMarginWithDistanceLoss.__init__(
  1987. self, **self.param_dict)
  1988. class _Loss(nn.modules.loss._Loss, FateTorchLoss):
  1989. def __init__(
  1990. self,
  1991. size_average=None,
  1992. reduce=None,
  1993. reduction='mean',
  1994. **kwargs):
  1995. FateTorchLoss.__init__(self)
  1996. self.param_dict['size_average'] = size_average
  1997. self.param_dict['reduce'] = reduce
  1998. self.param_dict['reduction'] = reduction
  1999. self.param_dict.update(kwargs)
  2000. nn.modules.loss._Loss.__init__(self, **self.param_dict)
  2001. class _WeightedLoss(nn.modules.loss._WeightedLoss, FateTorchLoss):
  2002. def __init__(
  2003. self,
  2004. weight=None,
  2005. size_average=None,
  2006. reduce=None,
  2007. reduction='mean',
  2008. **kwargs):
  2009. FateTorchLoss.__init__(self)
  2010. self.param_dict['weight'] = weight
  2011. self.param_dict['size_average'] = size_average
  2012. self.param_dict['reduce'] = reduce
  2013. self.param_dict['reduction'] = reduction
  2014. self.param_dict.update(kwargs)
  2015. nn.modules.loss._WeightedLoss.__init__(self, **self.param_dict)