k_fold.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import functools
  18. import numpy as np
  19. from sklearn.model_selection import KFold as sk_KFold
  20. from fate_arch.session import computing_session as session
  21. from federatedml.evaluation.evaluation import Evaluation
  22. from federatedml.model_selection.cross_validate import BaseCrossValidator
  23. from federatedml.model_selection.indices import collect_index
  24. from federatedml.transfer_variable.transfer_class.cross_validation_transfer_variable import \
  25. CrossValidationTransferVariable
  26. from federatedml.util import LOGGER
  27. from federatedml.util import consts
  class KFold(BaseCrossValidator):
      """K-fold cross-validation runner.

      For each of ``n_splits`` folds, a deep copy of the supplied model is
      trained on the fold's training part and then used to predict both the
      training and the validation part. Predictions are evaluated and, when
      ``output_fold_history`` is enabled, accumulated into ``self.fold_history``.
      """

      def __init__(self):
          super(KFold, self).__init__()
          self.model_param = None
          self.n_splits = 1
          self.shuffle = True
          self.random_seed = 1
          self.fold_history = None

      def _init_model(self, param):
          """Copy the cross-validation settings from ``param`` onto this object."""
          self.model_param = param
          self.n_splits = param.n_splits
          self.mode = param.mode
          self.role = param.role
          self.shuffle = param.shuffle
          self.random_seed = param.random_seed
          self.output_fold_history = param.output_fold_history
          self.history_value_type = param.history_value_type

      def split(self, data_inst):
          """Yield one ``(train_data, test_data)`` pair of tables per fold.

          The sample ids of ``data_inst`` are partitioned with scikit-learn's
          ``KFold``. Each partition is materialised by joining ``data_inst``
          against a table of the selected ids. Both yielded tables carry
          ``data_inst.schema``.
          """
          schema = data_inst.schema
          data_sids_iter, _ = collect_index(data_inst)
          data_sids = []
          key_type = None
          for sid, _value in data_sids_iter:
              # np.array may turn ids into numpy scalar types; remember the
              # original key type so selected ids can be cast back.
              if key_type is None:
                  key_type = type(sid)
              data_sids.append(sid)
          data_sids = np.array(data_sids)

          # random_state only has an effect when shuffling; recent scikit-learn
          # versions reject a random_state combined with shuffle=False.
          random_state = self.random_seed if self.shuffle else None
          kf = sk_KFold(n_splits=self.n_splits, shuffle=self.shuffle, random_state=random_state)
          for train_idx, test_idx in kf.split(data_sids):
              train_data = self._select_by_sids(data_inst, data_sids[train_idx], key_type, schema)
              test_data = self._select_by_sids(data_inst, data_sids[test_idx], key_type, schema)
              yield train_data, test_data

      @staticmethod
      def _select_by_sids(data_inst, sids, key_type, schema):
          """Return the rows of ``data_inst`` whose keys are in ``sids``, with ``schema`` attached."""
          sids_table = session.parallelize([(key_type(x), 1) for x in sids],
                                           include_key=True,
                                           partition=data_inst.partitions)
          selected = data_inst.join(sids_table, lambda x, y: x)
          selected.schema = schema
          return selected

      @staticmethod
      def generate_new_id(id, fold_num, data_type):
          """Build the fold-history key ``"<id>#fold<fold_num>#<data_type>"``."""
          return f"{id}#fold{fold_num}#{data_type}"

      def transform_history_data(self, data, predict_data, fold_num, data_type):
          """Re-key one fold's data for the fold-history output.

          Every key ``k`` becomes ``generate_new_id(k, fold_num, data_type)``.
          The values depend on ``self.history_value_type``:

          * ``"score"``: the values of ``predict_data``. When ``predict_data``
            is None, the value is the fold number and the header is ``["fold_num"]``.
          * ``"instance"``: the original values of ``data``.

          Raises:
              ValueError: if ``history_value_type`` is neither of the above.
          """
          if self.history_value_type == "score":
              if predict_data is not None:
                  history_data = predict_data.map(lambda k, v: (KFold.generate_new_id(k, fold_num, data_type), v))
                  history_data.schema = copy.deepcopy(predict_data.schema)
              else:
                  history_data = data.map(lambda k, v: (KFold.generate_new_id(k, fold_num, data_type), fold_num))
                  schema = copy.deepcopy(data.schema)
                  schema["header"] = ["fold_num"]
                  history_data.schema = schema
          elif self.history_value_type == "instance":
              history_data = data.map(lambda k, v: (KFold.generate_new_id(k, fold_num, data_type), v))
              history_data.schema = copy.deepcopy(data.schema)
          else:
              raise ValueError(f"unknown history value type: {self.history_value_type}")
          return history_data

      @staticmethod
      def _append_name(instance, name):
          """Return a deep copy of ``instance`` with ``name`` appended to its ``features``."""
          new_inst = copy.deepcopy(instance)
          new_inst.features.append(name)
          return new_inst

      def run(self, component_parameters, data_inst, original_model, host_do_evaluate):
          """Run k-fold cross validation.

          Args:
              component_parameters: parameter object read by ``_init_model``.
              data_inst: the input table, or None on the arbiter.
              original_model: model template; every fold trains a deep copy of it.
              host_do_evaluate: whether a non-guest party also evaluates predictions.

          Returns:
              None on the arbiter; the accumulated fold-history table when
              ``output_fold_history`` is enabled; otherwise ``data_inst``.

          Raises:
              EnvironmentError: if a fold's train and test counts do not sum to
                  the input count.
          """
          self._init_model(component_parameters)
          # Start from an empty history so a repeated run does not carry over
          # folds accumulated by a previous call.
          self.fold_history = None

          if data_inst is None:
              self._arbiter_run(original_model)
              return

          total_data_count = data_inst.count()
          LOGGER.debug(f"data_inst count: {total_data_count}")
          if self.output_fold_history:
              if total_data_count * self.n_splits > consts.MAX_SAMPLE_OUTPUT_LIMIT:
                  LOGGER.warning(
                      f"max sample output limit {consts.MAX_SAMPLE_OUTPUT_LIMIT} exceeded with n_splits ({self.n_splits}) * instance_count ({total_data_count})")

          if self.mode == consts.HOMO or self.role == consts.GUEST:
              data_generator = self.split(data_inst)
          else:
              # Hetero non-guest: the full table is passed for every fold and
              # _align_data_index narrows it to the ids received from the guest.
              data_generator = [(data_inst, data_inst)] * self.n_splits

          fold_num = 0
          summary_res = {}
          for train_data, test_data in data_generator:
              model = copy.deepcopy(original_model)
              LOGGER.debug("In CV, set_flowid flowid is : {}".format(fold_num))
              model.set_flowid(fold_num)
              model.set_cv_fold(fold_num)
              LOGGER.info("KFold fold_num is: {}".format(fold_num))

              if self.mode == consts.HETERO:
                  train_data = self._align_data_index(train_data, model.flowid, consts.TRAIN_DATA)
                  LOGGER.info("Train data Synchronized")
                  test_data = self._align_data_index(test_data, model.flowid, consts.TEST_DATA)
                  LOGGER.info("Test data Synchronized")

              train_data_count = train_data.count()
              test_data_count = test_data.count()
              LOGGER.debug(f"train_data count: {train_data_count}")
              if train_data_count + test_data_count != total_data_count:
                  raise EnvironmentError("In cv fold: {}, train count: {}, test count: {}, original data count: {}. "
                                         "Thus, 'train count + test count = total count' condition is not satisfied"
                                         .format(fold_num, train_data_count, test_data_count, total_data_count))

              this_flowid = 'train.' + str(fold_num)
              LOGGER.debug("In CV, set_flowid flowid is : {}".format(this_flowid))
              model.set_flowid(this_flowid)
              model.fit(train_data, test_data)

              train_pred_res = self._predict_and_evaluate(model, train_data, fold_num, 'train', host_do_evaluate)
              test_pred_res = self._predict_and_evaluate(model, test_data, fold_num, 'validate', host_do_evaluate)
              LOGGER.debug("Finish fold: {}".format(fold_num))

              if self.output_fold_history:
                  self._append_fold_history(train_data, train_pred_res, test_data, test_pred_res, fold_num)

              summary_res[f"fold_{fold_num}"] = model.summary()
              fold_num += 1

          summary_res['fold_num'] = fold_num
          LOGGER.debug("Finish all fold running")
          original_model.set_summary(summary_res)
          if self.output_fold_history:
              LOGGER.debug(f"output data schema: {self.fold_history.schema}")
              return self.fold_history
          return data_inst

      def _predict_and_evaluate(self, model, data, fold_num, data_type, host_do_evaluate):
          """Predict ``data`` under flowid ``predict_<data_type>.<fold_num>`` and evaluate it.

          The evaluation step runs on the guest, or elsewhere when ``host_do_evaluate``
          is set. It appends ``data_type`` to each prediction's features, applies
          ``data.schema`` via the model, and evaluates the result under the name
          ``<data_type>_fold_<fold_num>``.

          Returns:
              The predictions, tagged as described when the evaluation step ran.
          """
          this_flowid = 'predict_' + data_type + '.' + str(fold_num)
          LOGGER.debug("In CV, set_flowid flowid is : {}".format(this_flowid))
          model.set_flowid(this_flowid)
          pred_res = model.predict(data)
          if self.role == consts.GUEST or host_do_evaluate:
              fold_name = "_".join([data_type, 'fold', str(fold_num)])
              f = functools.partial(self._append_name, name=data_type)
              pred_res = pred_res.mapValues(f)
              pred_res = model.set_predict_data_schema(pred_res, data.schema)
              self.evaluate(pred_res, fold_name, model)
          return pred_res

      def _append_fold_history(self, train_data, train_pred_res, test_data, test_pred_res, fold_num):
          """Union this fold's train and validate history rows into ``self.fold_history``."""
          LOGGER.debug(f"generating fold history for fold {fold_num}")
          fold_train_data = self.transform_history_data(train_data, train_pred_res, fold_num, "train")
          fold_validate_data = self.transform_history_data(test_data, test_pred_res, fold_num, "validate")
          fold_history_data = fold_train_data.union(fold_validate_data)
          fold_history_data.schema = fold_train_data.schema
          if self.fold_history is None:
              self.fold_history = fold_history_data
          else:
              new_fold_history = self.fold_history.union(fold_history_data)
              new_fold_history.schema = fold_history_data.schema
              self.fold_history = new_fold_history

      def _arbiter_run(self, original_model):
          """Arbiter side of cross validation: no data, same per-fold flowids.

          For each fold, a deep copy of ``original_model`` runs ``fit(None)`` under
          ``train.<fold>``, then ``predict(None)`` under ``predict_train.<fold>``
          and ``predict_validate.<fold>``. These are the flowids ``run`` uses on
          the data-holding parties.
          """
          for fold_num in range(self.n_splits):
              LOGGER.info("KFold flowid is: {}".format(fold_num))
              model = copy.deepcopy(original_model)
              this_flowid = 'train.' + str(fold_num)
              model.set_flowid(this_flowid)
              model.set_cv_fold(fold_num)
              model.fit(None)

              this_flowid = 'predict_train.' + str(fold_num)
              model.set_flowid(this_flowid)
              model.predict(None)

              this_flowid = 'predict_validate.' + str(fold_num)
              model.set_flowid(this_flowid)
              model.predict(None)

      def _align_data_index(self, data_instance, flowid, data_application=None):
          """Make guest and host use the same sample ids for one fold's data.

          The guest sends its ids (every value replaced by 1) to the host role
          (``idx=-1``) and returns ``data_instance`` unchanged. The host receives
          the guest's ids and returns ``data_instance`` restricted to them, with
          the original schema.

          Args:
              data_instance: this party's table for the fold.
              flowid: transfer suffix distinguishing this fold.
              data_application: ``consts.TRAIN_DATA`` or ``consts.TEST_DATA``.

          Raises:
              ValueError: if ``data_application`` is missing or unrecognised, or
                  the role is neither guest nor host.
          """
          schema = data_instance.schema
          if data_application is None:
              raise ValueError("In _align_data_index, data_application should be provided.")

          transfer_variable = CrossValidationTransferVariable()
          if data_application == consts.TRAIN_DATA:
              transfer_id = transfer_variable.train_sid
          elif data_application == consts.TEST_DATA:
              transfer_id = transfer_variable.test_sid
          else:
              raise ValueError(f"In _align_data_index, unknown data_application: {data_application}")

          if self.role == consts.GUEST:
              data_sid = data_instance.mapValues(lambda v: 1)
              transfer_id.remote(data_sid,
                                 role=consts.HOST,
                                 idx=-1,
                                 suffix=(flowid,))
              LOGGER.info("remote {} to host".format(data_application))
              return data_instance

          if self.role == consts.HOST:
              data_sid = transfer_id.get(idx=0,
                                         suffix=(flowid,))
              LOGGER.info("get {} from guest".format(data_application))
              join_data_insts = data_sid.join(data_instance, lambda s, d: d)
              join_data_insts.schema = schema
              return join_data_insts

          # Previously fell through and returned None, which only failed later
          # (e.g. on .count()); fail here with a clear message instead.
          raise ValueError(f"In _align_data_index, unsupported role: {self.role}")

      def evaluate(self, validate_data, fold_name, model):
          """Evaluate one fold's predictions using the model's metric parameters.

          Does nothing when ``validate_data`` is None. Otherwise it builds an
          ``Evaluation`` from ``model.get_metrics_param()`` and attaches
          ``model.tracker``. It then fits on ``{fold_name: validate_data}`` and
          calls ``save_data()``.
          """
          if validate_data is None:
              return
          eval_obj = Evaluation()
          eval_param = model.get_metrics_param()
          eval_param.check_single_value_default_metric()
          eval_obj._init_model(eval_param)
          eval_obj.set_tracker(model.tracker)
          validate_data = {fold_name: validate_data}
          eval_obj.fit(validate_data)
          eval_obj.save_data()
  246. eval_obj.save_data()