linear_model_base.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import numpy as np
  19. from federatedml.model_base import Metric
  20. from federatedml.model_base import MetricMeta
  21. from federatedml.feature.sparse_vector import SparseVector
  22. from federatedml.model_base import ModelBase
  23. from federatedml.model_selection import start_cross_validation
  24. from federatedml.model_selection.stepwise import start_stepwise
  25. from federatedml.optim.convergence import converge_func_factory
  26. from federatedml.optim.initialize import Initializer
  27. from federatedml.optim.optimizer import optimizer_factory
  28. from federatedml.statistic import data_overview
  29. from federatedml.util import LOGGER
  30. from federatedml.util import abnormal_detection
  31. from federatedml.util import consts
  32. from federatedml.callbacks.validation_strategy import ValidationStrategy
  33. class BaseLinearModel(ModelBase):
  34. def __init__(self):
  35. super(BaseLinearModel, self).__init__()
  36. # attribute:
  37. self.n_iter_ = 0
  38. self.classes_ = None
  39. self.feature_shape = None
  40. self.gradient_operator = None
  41. self.initializer = Initializer()
  42. self.transfer_variable = None
  43. self.loss_history = []
  44. self.is_converged = False
  45. self.header = None
  46. self.model_name = 'toSet'
  47. self.model_param_name = 'toSet'
  48. self.model_meta_name = 'toSet'
  49. self.role = ''
  50. self.mode = ''
  51. self.schema = {}
  52. self.cipher_operator = None
  53. self.model_weights = None
  54. self.validation_freqs = None
  55. self.need_one_vs_rest = False
  56. self.need_call_back_loss = True
  57. self.init_param_obj = None
  58. self.early_stop = None
  59. self.tol = None
  60. def _init_model(self, params):
  61. self.model_param = params
  62. self.alpha = params.alpha
  63. self.init_param_obj = params.init_param
  64. # self.fit_intercept = self.init_param_obj.fit_intercept
  65. self.batch_size = params.batch_size
  66. if hasattr(params, "shuffle"):
  67. self.shuffle = params.shuffle
  68. if hasattr(params, "masked_rate"):
  69. self.masked_rate = params.masked_rate
  70. if hasattr(params, "batch_strategy"):
  71. self.batch_strategy = params.batch_strategy
  72. self.max_iter = params.max_iter
  73. self.optimizer = optimizer_factory(params)
  74. self.early_stop = params.early_stop
  75. self.tol = params.tol
  76. self.converge_func = converge_func_factory(params.early_stop, params.tol)
  77. self.validation_freqs = params.callback_param.validation_freqs
  78. self.validation_strategy = None
  79. self.early_stopping_rounds = params.callback_param.early_stopping_rounds
  80. self.metrics = params.callback_param.metrics
  81. self.use_first_metric_only = params.callback_param.use_first_metric_only
  82. # if len(self.component_properties.host_party_idlist) == 1:
  83. # LOGGER.debug(f"set_use_async")
  84. # self.gradient_loss_operator.set_use_async()
  85. def get_features_shape(self, data_instances):
  86. if self.feature_shape is not None:
  87. return self.feature_shape
  88. return data_overview.get_features_shape(data_instances)
    def set_header(self, header):
        # Record the feature-name list used for export / weight reporting.
        self.header = header
  91. def get_header(self, data_instances):
  92. if self.header is not None:
  93. return self.header
  94. return data_instances.schema.get("header", [])
    @property
    def fit_intercept(self):
        # Whether an intercept term is learned; delegated to the init-param object.
        return self.init_param_obj.fit_intercept
  98. def _get_meta(self):
  99. raise NotImplementedError("This method should be be called here")
  100. def _get_param(self):
  101. raise NotImplementedError("This method should be be called here")
  102. def export_model(self):
  103. LOGGER.debug(f"called export model")
  104. meta_obj = self._get_meta()
  105. param_obj = self._get_param()
  106. result = {
  107. self.model_meta_name: meta_obj,
  108. self.model_param_name: param_obj
  109. }
  110. return result
    def disable_callback_loss(self):
        # Suppress per-iteration loss reporting (e.g. during stepwise selection).
        self.need_call_back_loss = False
    def enable_callback_loss(self):
        # Re-enable per-iteration loss reporting.
        self.need_call_back_loss = True
  115. def callback_loss(self, iter_num, loss):
  116. metric_meta = MetricMeta(name='train',
  117. metric_type="LOSS",
  118. extra_metas={
  119. "unit_name": "iters",
  120. })
  121. self.callback_meta(metric_name='loss', metric_namespace='train', metric_meta=metric_meta)
  122. self.callback_metric(metric_name='loss',
  123. metric_namespace='train',
  124. metric_data=[Metric(iter_num, loss)])
    def _abnormal_detection(self, data_instances):
        """
        Make sure input data_instances is valid.

        Rejects empty tables, tables whose rows carry no features, and
        malformed schema contents; each helper raises on failure.
        """
        abnormal_detection.empty_table_detection(data_instances)
        abnormal_detection.empty_feature_detection(data_instances)
        ModelBase.check_schema_content(data_instances.schema)
  132. def init_validation_strategy(self, train_data=None, validate_data=None):
  133. validation_strategy = ValidationStrategy(self.role, self.mode, self.validation_freqs,
  134. self.early_stopping_rounds,
  135. self.use_first_metric_only)
  136. validation_strategy.set_train_data(train_data)
  137. validation_strategy.set_validate_data(validate_data)
  138. return validation_strategy
    def cross_validation(self, data_instances):
        # Delegate to the shared cross-validation driver, passing self as the model.
        return start_cross_validation.run(self, data_instances)
    def stepwise(self, data_instances):
        # Stepwise selection drives its own fitting loop, so per-iteration
        # loss callbacks are switched off first.
        self.disable_callback_loss()
        return start_stepwise.run(self, data_instances)
  144. def _get_cv_param(self):
  145. self.model_param.cv_param.role = self.role
  146. self.model_param.cv_param.mode = self.mode
  147. return self.model_param.cv_param
  148. def _get_stepwise_param(self):
  149. self.model_param.stepwise_param.role = self.role
  150. self.model_param.stepwise_param.mode = self.mode
  151. return self.model_param.stepwise_param
  152. def set_schema(self, data_instance, header=None):
  153. if header is None:
  154. self.schema["header"] = self.header
  155. else:
  156. self.schema["header"] = header
  157. data_instance.schema = self.schema
  158. return data_instance
  159. def init_schema(self, data_instance):
  160. if data_instance is None:
  161. return
  162. self.schema = data_instance.schema
  163. self.header = self.schema.get('header')
  164. def get_weight_intercept_dict(self, header):
  165. weight_dict = {}
  166. for idx, header_name in enumerate(header):
  167. coef_i = self.model_weights.coef_[idx]
  168. weight_dict[header_name] = coef_i
  169. intercept_ = self.model_weights.intercept_
  170. return weight_dict, intercept_
  171. def get_model_summary(self):
  172. header = self.header
  173. if header is None:
  174. return {}
  175. weight_dict, intercept_ = self.get_weight_intercept_dict(header)
  176. summary = {"coef": weight_dict,
  177. "intercept": intercept_,
  178. "is_converged": self.is_converged,
  179. "best_iteration": self.callback_variables.best_iteration}
  180. if self.callback_variables.validation_summary is not None:
  181. summary["validation_metrics"] = self.callback_variables.validation_summary
  182. return summary
  183. def check_abnormal_values(self, data_instances):
  184. if data_instances is None:
  185. return
  186. def _check_overflow(data_iter):
  187. for _, instant in data_iter:
  188. features = instant.features
  189. if isinstance(features, SparseVector):
  190. sparse_data = features.get_all_data()
  191. for k, v in sparse_data:
  192. if np.abs(v) > consts.OVERFLOW_THRESHOLD:
  193. return True
  194. else:
  195. if np.max(np.abs(features)) > consts.OVERFLOW_THRESHOLD:
  196. return True
  197. return False
  198. check_status = data_instances.applyPartitions(_check_overflow)
  199. is_overflow = check_status.reduce(lambda a, b: a or b)
  200. if is_overflow:
  201. raise OverflowError("The value range of features is too large for GLM, please have "
  202. "a check for input data")
  203. LOGGER.info("Check for abnormal value passed")
    def prepare_fit(self, data_instances, validate_data):
        # Cache the header before training, then run sanity checks:
        # schema/emptiness on the training table and value-range checks on
        # both training and validation data (the latter may be None).
        self.header = self.get_header(data_instances)
        self._abnormal_detection(data_instances)
        self.check_abnormal_values(data_instances)
        self.check_abnormal_values(validate_data)