optim.py

from torch import optim
from federatedml.nn.backend.torch.base import FateTorchLayer, Sequential
from federatedml.nn.backend.torch.base import FateTorchOptimizer
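
# Each class below wraps the torch.optim optimizer of the same name:
# constructor arguments are recorded in self.param_dict (provided by
# FateTorchOptimizer), and the underlying torch optimizer is only
# initialized when `params` is actually supplied; otherwise the instance
# simply carries its hyper-parameter configuration.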


class ASGD(optim.ASGD, FateTorchOptimizer):

    def __init__(
            self,
            params=None,
            lr=0.01,
            lambd=0.0001,
            alpha=0.75,
            t0=1000000.0,
            weight_decay=0,
            foreach=None,
            maximize=False,
    ):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['lambd'] = lambd
        self.param_dict['alpha'] = alpha
        self.param_dict['t0'] = t0
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['foreach'] = foreach
        self.param_dict['maximize'] = maximize
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.ASGD.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class Adadelta(optim.Adadelta, FateTorchOptimizer):

    def __init__(self, params=None, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0, foreach=None):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['rho'] = rho
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['foreach'] = foreach
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.Adadelta.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class Adagrad(optim.Adagrad, FateTorchOptimizer):

    def __init__(
            self,
            params=None,
            lr=0.01,
            lr_decay=0,
            weight_decay=0,
            initial_accumulator_value=0,
            eps=1e-10,
            foreach=None,
    ):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['lr_decay'] = lr_decay
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['initial_accumulator_value'] = initial_accumulator_value
        self.param_dict['eps'] = eps
        self.param_dict['foreach'] = foreach
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.Adagrad.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class Adam(optim.Adam, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['amsgrad'] = amsgrad
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.Adam.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class AdamW(optim.AdamW, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['amsgrad'] = amsgrad
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.AdamW.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class Adamax(optim.Adamax, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, foreach=None):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['foreach'] = foreach
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.Adamax.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class LBFGS(optim.LBFGS, FateTorchOptimizer):

    def __init__(
            self,
            params=None,
            lr=1,
            max_iter=20,
            max_eval=None,
            tolerance_grad=1e-07,
            tolerance_change=1e-09,
            history_size=100,
            line_search_fn=None,
    ):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['max_iter'] = max_iter
        self.param_dict['max_eval'] = max_eval
        self.param_dict['tolerance_grad'] = tolerance_grad
        self.param_dict['tolerance_change'] = tolerance_change
        self.param_dict['history_size'] = history_size
        self.param_dict['line_search_fn'] = line_search_fn
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.LBFGS.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class NAdam(optim.NAdam, FateTorchOptimizer):

    def __init__(
            self,
            params=None,
            lr=0.002,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0,
            momentum_decay=0.004,
            foreach=None,
    ):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['momentum_decay'] = momentum_decay
        self.param_dict['foreach'] = foreach
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.NAdam.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class RAdam(optim.RAdam, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, foreach=None):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['foreach'] = foreach
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.RAdam.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class RMSprop(optim.RMSprop, FateTorchOptimizer):

    def __init__(
            self,
            params=None,
            lr=0.01,
            alpha=0.99,
            eps=1e-08,
            weight_decay=0,
            momentum=0,
            centered=False,
            foreach=None,
            maximize=False,
            differentiable=False,
    ):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['alpha'] = alpha
        self.param_dict['eps'] = eps
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['momentum'] = momentum
        self.param_dict['centered'] = centered
        self.param_dict['foreach'] = foreach
        self.param_dict['maximize'] = maximize
        self.param_dict['differentiable'] = differentiable
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.RMSprop.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class Rprop(optim.Rprop, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50), foreach=None, maximize=False):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['etas'] = etas
        self.param_dict['step_sizes'] = step_sizes
        self.param_dict['foreach'] = foreach
        self.param_dict['maximize'] = maximize
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.Rprop.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class SGD(optim.SGD, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['momentum'] = momentum
        self.param_dict['dampening'] = dampening
        self.param_dict['weight_decay'] = weight_decay
        self.param_dict['nesterov'] = nesterov
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.SGD.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)


class SparseAdam(optim.SparseAdam, FateTorchOptimizer):

    def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08, maximize=False):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.param_dict['betas'] = betas
        self.param_dict['eps'] = eps
        self.param_dict['maximize'] = maximize
        self.torch_class = type(self).__bases__[0]
        if params is None:
            return
        params = self.check_params(params)
        self.torch_class.__init__(self, params, **self.param_dict)
        # optim.SparseAdam.__init__(self, **self.param_dict)

    def __repr__(self):
        try:
            return type(self).__bases__[0].__repr__(self)
        except BaseException:
            return 'Optimizer {} without initialized parameters'.format(type(self).__name__)
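

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of the wrapper behaviour visible above: constructed
# without `params`, an optimizer only records its hyper-parameters in
# param_dict and the placeholder __repr__ branch applies; constructed
# with `params`, it initializes the underlying torch optimizer and can
# be driven like any torch.optim optimizer. This assumes that
# FateTorchOptimizer.check_params (defined in the base module, not
# shown here) accepts a plain iterable of torch parameters.
if __name__ == '__main__':
    import torch
    from torch import nn

    model = nn.Linear(4, 1)

    # Without params: hyper-parameter configuration only.
    cfg_only = SGD(lr=0.05, momentum=0.9)
    print(cfg_only.param_dict)
    print(cfg_only)

    # With params: a fully initialized torch.optim.SGD.
    opt = SGD(model.parameters(), lr=0.05, momentum=0.9)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()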