one_vs_rest.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import functools
  18. from federatedml.model_base import ModelBase
  19. from federatedml.transfer_variable.transfer_class.one_vs_rest_transfer_variable import OneVsRestTransferVariable
  20. from federatedml.util import LOGGER
  21. from federatedml.util import consts
  22. from federatedml.util.classify_label_checker import ClassifyLabelChecker
  23. from federatedml.util.io_check import assert_io_num_rows_equal
  class OneVsRest(object):
      """
      One-vs-rest wrapper: trains one binary classifier per class and combines their scores.

      Subclasses decide which role holds labels (`has_label`) and how the class list is
      synchronized between parties (`_sync_class_guest`, `_sync_class_host`,
      `_sync_class_arbiter`); the base versions raise NotImplementedError.
      """

      def __init__(self, classifier, role, mode, has_arbiter):
          """
          Parameters:
          ----------
          classifier: binary classifier template; fit() trains a deep copy of it per class
          role: role of this party, compared against consts.GUEST / consts.HOST
          mode: federation mode, used in log messages
          has_arbiter: whether an arbiter party takes part in class synchronization
          """
          self.classifier = classifier
          self.transfer_variable = OneVsRestTransferVariable()
          self.classes = None
          self.role = role
          self.mode = mode
          self.flow_id = 0
          self.has_arbiter = has_arbiter
          self.models = []  # one trained binary classifier per entry of self.classes
          self.class_name = self.__class__.__name__

      @staticmethod
      def __get_multi_class_res(instance, classes):
          """
          Return (class with the highest probability, that probability, {class: probability}).

          `instance` is a sequence of per-class probabilities aligned index-by-index with `classes`.
          NOTE(review): not called anywhere in this class; kept for compatibility.
          """
          max_prob = -1
          max_prob_index = -1
          instance_with_class = {}
          for (i, prob) in enumerate(instance):
              instance_with_class[classes[i]] = prob
              if prob > max_prob:
                  max_prob = prob
                  max_prob_index = i

          return classes[max_prob_index], max_prob, instance_with_class

      def get_data_classes(self, data_instances):
          """
          Collect the local label set (only if this party has labels), synchronize the
          class list with the other parties, and return the resulting self.classes.
          """
          class_set = None
          if self.has_label:
              num_class, class_list = ClassifyLabelChecker.validate_label(data_instances)
              class_set = set(class_list)

          self._synchronize_classes_list(class_set)
          return self.classes

      @staticmethod
      def _mask_data_label(data_instances, label):
          """
          Return a new Table whose instance labels are 1 where the original label equals
          `label` and 0 otherwise.

          Each instance is shallow-copied before relabeling so the caller's table is never
          modified: fit() masks the same input once per class, and an in-place update could
          leak one round's 0/1 labels into the next if the computing backend hands mapValues
          the original objects.
          """
          def do_mask_label(instance):
              masked = copy.copy(instance)
              masked.label = 1 if instance.label == label else 0
              return masked

          return data_instances.mapValues(do_mask_label)

      def _sync_class_guest(self, class_set):
          raise NotImplementedError("Function should not be called here")

      def _sync_class_host(self, class_set):
          raise NotImplementedError("Function should not be called here")

      def _sync_class_arbiter(self):
          raise NotImplementedError("Function should not be called here")

      def _synchronize_classes_list(self, class_set):
          """
          Dispatch class-list synchronization to the hook for this party's role:
          guest -> _sync_class_guest, host -> _sync_class_host, anything else -> _sync_class_arbiter.
          Each hook is expected to set self.classes.
          """
          if self.role == consts.GUEST:
              self._sync_class_guest(class_set)
          elif self.role == consts.HOST:
              self._sync_class_host(class_set)
          else:
              self._sync_class_arbiter()

      @property
      def has_label(self):
          raise NotImplementedError("Function should not be called here")

      def fit(self, data_instances=None, validate_data=None):
          """
          Fit OneVsRest model: train one binary classifier per entry of self.classes.

          Parameters:
          ----------
          data_instances: Table of instances
          validate_data: optional Table of validation instances; masked like data_instances
                         when this party has labels
          """
          LOGGER.info("mode is {}, role is {}, start to one_vs_rest fit".format(self.mode, self.role))
          LOGGER.info("Total classes:{}".format(self.classes))

          self.classifier.callback_one_vs_rest = True
          current_flow_id = self.classifier.flowid

          summary_dict = {}
          for label_index, label in enumerate(self.classes):
              LOGGER.info("Start to train OneVsRest with label_index:{}, label:{}".format(label_index, label))

              classifier = copy.deepcopy(self.classifier)
              classifier.need_one_vs_rest = False
              # presumably gives each sub-model its own namespace for federated communication
              classifier.set_flowid(".".join([current_flow_id, "model_" + str(label_index)]))

              if self.has_label:
                  header = data_instances.schema.get("header")
                  data_instances_mask_label = self._mask_data_label(data_instances, label=label)
                  data_instances_mask_label.schema['header'] = header

                  if validate_data is not None:
                      validate_mask_label_data = self._mask_data_label(validate_data, label=label)
                      validate_mask_label_data.schema['header'] = header
                  else:
                      validate_mask_label_data = validate_data

                  LOGGER.info("finish mask label:{}".format(label))
                  LOGGER.info("start classifier fit")
                  classifier.fit_binary(data_instances_mask_label, validate_data=validate_mask_label_data)
              else:
                  LOGGER.info("start classifier fit")
                  classifier.fit_binary(data_instances, validate_data=validate_data)

              _summary = classifier.summary()
              _summary['one_vs_rest'] = True
              summary_dict[label] = _summary
              self.models.append(classifier)

              # subclasses that declare a `header` attribute adopt the first trained model's header if unset
              if hasattr(self, "header") and getattr(self, "header") is None:
                  setattr(self, "header", getattr(classifier, "header"))

              LOGGER.info("Finish model_{} training!".format(label_index))

          self.classifier.set_summary(summary_dict)

      def _comprehensive_result(self, predict_res_list):
          """
          Merge the per-class binary predictions into one Table of per-class score lists.

          For each key the value is [score of model 0, score of model 1, ...], built from
          index 2 of every prediction's `features`. NOTE(review): index 2 is presumably the
          positive-class probability in the classifier's predict output; confirm.
          Only the guest produces scores; other roles get None.
          """
          if self.role == consts.GUEST:
              prob = predict_res_list[0].mapValues(lambda r: [r.features[2]])
              for predict_res in predict_res_list[1:]:
                  prob = prob.join(predict_res, lambda p, r: p + [r.features[2]])
          else:
              prob = None

          return prob

      @assert_io_num_rows_equal
      def predict(self, data_instances):
          """
          Predict OneVsRest model

          Parameters:
          ----------
          data_instances: Table of instances

          Returns:
          ----------
          predict_res: Table built by ModelBase.predict_score_to_output from the per-class
                       scores, or None on parties that produce no scores
          """
          LOGGER.info("Start one_vs_all predict procedure.")
          predict_res_list = []
          for i, model in enumerate(self.models):
              # NOTE: the suffix is appended to the current flow id on every predict call, so
              # repeated calls keep extending it (presumably to keep communication tags unique).
              current_flow_id = model.flowid
              model.set_flowid(".".join([current_flow_id, "model_" + str(i)]))

              LOGGER.info("Start to predict with model:{}".format(i))
              single_predict_res = model.predict(data_instances)
              predict_res_list.append(single_predict_res)

          prob = self._comprehensive_result(predict_res_list)
          # `_comprehensive_result` returns None or a Table; test identity, not Table truthiness
          if prob is not None:
              predict_res = ModelBase.predict_score_to_output(data_instances, prob, list(self.classes))
          else:
              predict_res = None

          return predict_res

      def save(self, single_model_pb):
          """
          Save each classifier model of OneVsRest (model_param only, not model_meta).

          Returns a dict with 'completed_models' (one `single_model_pb` per trained model,
          in class order) and 'one_vs_rest_classes' (the classes converted to str).
          """
          classifier_pb_objs = []
          for classifier in self.models:
              single_param_dict = classifier.get_single_model_param()
              classifier_pb_objs.append(single_model_pb(**single_param_dict))

          one_vs_rest_class = [str(x) for x in self.classes]
          one_vs_rest_result = {
              'completed_models': classifier_pb_objs,
              'one_vs_rest_classes': one_vs_rest_class
          }
          return one_vs_rest_result

      @staticmethod
      def _restore_class_label(raw_label):
          """
          Convert a class label saved by `save` (always a str) back to a Python value.

          Integer strings become int, exactly as load_model always did. Other strings are
          tried as float and otherwise kept as str, instead of aborting the load.
          """
          for cast in (int, float):
              try:
                  return cast(raw_label)
              except ValueError:
                  continue
          return raw_label

      def load_model(self, one_vs_rest_result):
          """
          Load OneVsRest model: restore the class list and rebuild one classifier per
          saved sub-model from deep copies of self.classifier.
          """
          completed_models = one_vs_rest_result.completed_models
          one_vs_rest_classes = one_vs_rest_result.one_vs_rest_classes
          self.classes = [self._restore_class_label(x) for x in one_vs_rest_classes]

          self.models = []
          for classifier_obj in list(completed_models):
              classifier = copy.deepcopy(self.classifier)
              classifier.load_single_model(classifier_obj)
              classifier.need_one_vs_rest = False
              self.models.append(classifier)
  class HomoOneVsRest(OneVsRest):
      """
      One-vs-rest for homo (horizontal) federation.

      Every party except the arbiter holds labels. The guest unions its own label set
      with every host's, broadcasts the merged list to hosts, and sends only the class
      count to the arbiter (when there is one).
      """

      def __init__(self, classifier, role, mode, has_arbiter):
          super().__init__(classifier, role, mode, has_arbiter)
          # set via set_header, or adopted from the first trained model during fit
          self.header = None

      def set_header(self, header):
          self.header = header

      @property
      def has_label(self):
          is_arbiter = self.role == consts.ARBITER
          return not is_arbiter

      def _sync_class_guest(self, class_set):
          reported_by_hosts = self.transfer_variable.host_classes.get(idx=-1)
          # `|` builds a new set each step, so the caller's class_set is left untouched
          merged = functools.reduce(lambda acc, host_set: acc | host_set, reported_by_hosts, class_set)
          self.classes = list(merged)

          aggregate = self.transfer_variable.aggregate_classes
          aggregate.remote(self.classes, role=consts.HOST, idx=-1)
          if self.has_arbiter:
              aggregate.remote(len(self.classes), role=consts.ARBITER, idx=0)

      def _sync_class_host(self, class_set):
          self.transfer_variable.host_classes.remote(class_set, role=consts.GUEST, idx=0)
          self.classes = self.transfer_variable.aggregate_classes.get(idx=0)

      def _sync_class_arbiter(self):
          # the arbiter only learns how many classes exist and uses 0..n-1 as placeholders
          total = self.transfer_variable.aggregate_classes.get(idx=0)
          self.classes = list(range(total))
  class HeteroOneVsRest(OneVsRest):
      """
      One-vs-rest for hetero (vertical) federation.

      Only the guest holds labels. Hosts and the arbiter receive just the class count
      and use 0..count-1 as their local class identifiers.
      """

      @property
      def has_label(self):
          return self.role == consts.GUEST

      def _sync_class_guest(self, class_set):
          self.classes = list(class_set)
          count = len(self.classes)

          aggregate = self.transfer_variable.aggregate_classes
          aggregate.remote(count, role=consts.HOST, idx=-1)
          if self.has_arbiter:
              aggregate.remote(count, role=consts.ARBITER, idx=0)

      def _receive_class_indices(self):
          """Fetch the class count sent by the guest and store 0..count-1 as self.classes."""
          count = self.transfer_variable.aggregate_classes.get(idx=0)
          self.classes = list(range(count))

      def _sync_class_host(self, class_set):
          LOGGER.debug("Start to get aggregate classes")
          self._receive_class_indices()

      def _sync_class_arbiter(self):
          self._receive_class_indices()
  def one_vs_rest_factory(classifier, role, mode, has_arbiter):
      """
      Build the one-vs-rest driver matching the federation `mode`.

      Returns a HomoOneVsRest for consts.HOMO and a HeteroOneVsRest for consts.HETERO;
      raises ValueError for any other mode.
      """
      LOGGER.info("Create one_vs_rest object, role: {}, mode: {}".format(role, mode))

      # checked in order with `==`, mirroring an if/elif chain
      candidates = (
          (consts.HOMO, HomoOneVsRest),
          (consts.HETERO, HeteroOneVsRest),
      )
      for supported_mode, driver_cls in candidates:
          if mode == supported_mode:
              return driver_cls(classifier, role, mode, has_arbiter)

      raise ValueError(f"Cannot recognize mode: {mode} in one vs rest")