# convergence.py
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. from federatedml.util import LOGGER
  18. from federatedml.util import fate_operator
  19. class _ConvergeFunction:
  20. def __init__(self, eps):
  21. self.eps = eps
  22. def is_converge(self, loss): pass
  23. class _DiffConverge(_ConvergeFunction):
  24. """
  25. Judge convergence by the difference between two iterations.
  26. If the difference is smaller than eps, converge flag will be provided.
  27. """
  28. def __init__(self, eps):
  29. super().__init__(eps=eps)
  30. self.pre_loss = None
  31. def is_converge(self, loss):
  32. LOGGER.debug("In diff converge function, pre_loss: {}, current_loss: {}".format(self.pre_loss, loss))
  33. converge_flag = False
  34. if self.pre_loss is None:
  35. pass
  36. elif abs(self.pre_loss - loss) < self.eps:
  37. converge_flag = True
  38. self.pre_loss = loss
  39. return converge_flag
  40. class _AbsConverge(_ConvergeFunction):
  41. """
  42. Judge converge by absolute loss value. When loss value smaller than eps, converge flag
  43. will be provided.
  44. """
  45. def is_converge(self, loss):
  46. if loss <= self.eps:
  47. converge_flag = True
  48. else:
  49. converge_flag = False
  50. return converge_flag
  51. class _WeightDiffConverge(_ConvergeFunction):
  52. """
  53. Use 2-norm of weight difference to judge whether converge or not.
  54. """
  55. def __init__(self, eps):
  56. super().__init__(eps=eps)
  57. self.pre_weight = None
  58. def is_converge(self, weight):
  59. if self.pre_weight is None:
  60. self.pre_weight = weight
  61. return False
  62. weight_diff = fate_operator.norm(self.pre_weight - weight)
  63. self.pre_weight = weight
  64. if weight_diff < self.eps * np.max([fate_operator.norm(weight), 1]):
  65. return True
  66. return False
  67. def converge_func_factory(early_stop, tol):
  68. # try:
  69. # converge_func = param.converge_func
  70. # eps = param.eps
  71. # except AttributeError:
  72. # raise AttributeError("Converge Function parameters has not been totally set")
  73. if early_stop == 'diff':
  74. return _DiffConverge(tol)
  75. elif early_stop == 'weight_diff':
  76. return _WeightDiffConverge(tol)
  77. elif early_stop == 'abs':
  78. return _AbsConverge(tol)
  79. else:
  80. raise NotImplementedError("Converge Function method cannot be recognized: {}".format(early_stop))