# component_properties.py
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import functools
  19. import numpy as np
  20. from fate_arch.computing import is_table
  21. from federatedml.util import LOGGER
  22. class RunningFuncs(object):
  23. def __init__(self):
  24. self.todo_func_list = []
  25. self.todo_func_params = []
  26. self.save_result = []
  27. self.use_previews_result = []
  28. def add_func(self, func, params, save_result=False, use_previews=False):
  29. self.todo_func_list.append(func)
  30. self.todo_func_params.append(params)
  31. self.save_result.append(save_result)
  32. self.use_previews_result.append(use_previews)
  33. def __iter__(self):
  34. for func, params, save_result, use_previews in zip(
  35. self.todo_func_list,
  36. self.todo_func_params,
  37. self.save_result,
  38. self.use_previews_result,
  39. ):
  40. yield func, params, save_result, use_previews
  41. class DSLConfigError(ValueError):
  42. pass
  43. class ComponentProperties(object):
  44. def __init__(self):
  45. self.need_cv = False
  46. self.need_run = False
  47. self.need_stepwise = False
  48. self.has_model = False
  49. self.has_isometric_model = False
  50. self.has_train_data = False
  51. self.has_eval_data = False
  52. self.has_validate_data = False
  53. self.has_test_data = False
  54. self.has_normal_input_data = False
  55. self.role = None
  56. self.host_party_idlist = []
  57. self.local_partyid = -1
  58. self.guest_partyid = -1
  59. self.input_data_count = 0
  60. self.input_eval_data_count = 0
  61. self.caches = None
  62. self.is_warm_start = False
  63. self.has_arbiter = False
  64. def parse_caches(self, caches):
  65. self.caches = caches
  66. def parse_component_param(self, roles, param):
  67. try:
  68. need_cv = param.cv_param.need_cv
  69. except AttributeError:
  70. need_cv = False
  71. self.need_cv = need_cv
  72. try:
  73. need_run = param.need_run
  74. except AttributeError:
  75. need_run = True
  76. self.need_run = need_run
  77. LOGGER.debug("need_run: {}, need_cv: {}".format(self.need_run, self.need_cv))
  78. try:
  79. need_stepwise = param.stepwise_param.need_stepwise
  80. except AttributeError:
  81. need_stepwise = False
  82. self.need_stepwise = need_stepwise
  83. self.has_arbiter = roles["role"].get("arbiter") is not None
  84. self.role = roles["local"]["role"]
  85. self.host_party_idlist = roles["role"].get("host")
  86. self.local_partyid = roles["local"].get("party_id")
  87. self.guest_partyid = roles["role"].get("guest")
  88. if self.guest_partyid is not None:
  89. self.guest_partyid = self.guest_partyid[0]
  90. return self
  91. def parse_dsl_args(self, datasets, model):
  92. if "model" in model and model["model"] is not None:
  93. self.has_model = True
  94. if "isometric_model" in model and model["isometric_model"] is not None:
  95. self.has_isometric_model = True
  96. LOGGER.debug(f"parse_dsl_args data_sets: {datasets}")
  97. if datasets is None:
  98. return self
  99. for data_key, data_dicts in datasets.items():
  100. data_keys = list(data_dicts.keys())
  101. for data_type in ["train_data", "eval_data", "validate_data", "test_data"]:
  102. if data_type in data_keys:
  103. setattr(self, f"has_{data_type}", True)
  104. data_keys.remove(data_type)
  105. LOGGER.debug(
  106. f"[Data Parser], has_{data_type}:"
  107. f" {getattr(self, f'has_{data_type}')}"
  108. )
  109. if len(data_keys) > 0:
  110. self.has_normal_input_data = True
  111. LOGGER.debug(
  112. "[Data Parser], has_normal_data: {}".format(self.has_normal_input_data)
  113. )
  114. if self.has_eval_data:
  115. if self.has_validate_data or self.has_test_data:
  116. raise DSLConfigError(
  117. "eval_data input should not be configured simultaneously"
  118. " with validate_data or test_data"
  119. )
  120. # self._abnormal_dsl_config_detect()
  121. if self.has_model and self.has_train_data:
  122. self.is_warm_start = True
  123. return self
  124. def _abnormal_dsl_config_detect(self):
  125. if self.has_validate_data:
  126. if not self.has_train_data:
  127. raise DSLConfigError(
  128. "validate_data should be configured simultaneously"
  129. " with train_data"
  130. )
  131. if self.has_train_data:
  132. if self.has_normal_input_data or self.has_test_data:
  133. raise DSLConfigError(
  134. "train_data input should not be configured simultaneously"
  135. " with data or test_data"
  136. )
  137. if self.has_normal_input_data:
  138. if self.has_train_data or self.has_validate_data or self.has_test_data:
  139. raise DSLConfigError(
  140. "When data input has been configured, train_data, "
  141. "validate_data or test_data should not be configured."
  142. )
  143. if self.has_test_data:
  144. if not self.has_model:
  145. raise DSLConfigError(
  146. "When test_data input has been configured, model "
  147. "input should be configured too."
  148. )
  149. if self.need_cv or self.need_stepwise:
  150. if not self.has_train_data:
  151. raise DSLConfigError(
  152. "Train_data should be configured in cross-validate "
  153. "task or stepwise task"
  154. )
  155. if (
  156. self.has_validate_data
  157. or self.has_normal_input_data
  158. or self.has_test_data
  159. ):
  160. raise DSLConfigError(
  161. "Train_data should be set only if it is a cross-validate "
  162. "task or a stepwise task"
  163. )
  164. if self.has_model or self.has_isometric_model:
  165. raise DSLConfigError(
  166. "In cross-validate task or stepwise task, model "
  167. "or isometric_model should not be configured"
  168. )
  169. def extract_input_data(self, datasets, model):
  170. model_data = {}
  171. data = {}
  172. LOGGER.debug(f"Input data_sets: {datasets}")
  173. for cpn_name, data_dict in datasets.items():
  174. for data_type in ["train_data", "eval_data", "validate_data", "test_data"]:
  175. if data_type in data_dict:
  176. d_table = data_dict.get(data_type)
  177. model_data[data_type] = model.obtain_data(d_table)
  178. del data_dict[data_type]
  179. if len(data_dict) > 0:
  180. LOGGER.debug(f"data_dict: {data_dict}")
  181. for k, v in data_dict.items():
  182. data_list = model.obtain_data(v)
  183. LOGGER.debug(f"data_list: {data_list}")
  184. if isinstance(data_list, list):
  185. for i, data_i in enumerate(data_list):
  186. data[".".join([cpn_name, k, str(i)])] = data_i
  187. else:
  188. data[".".join([cpn_name, k])] = data_list
  189. train_data = model_data.get("train_data")
  190. validate_data = None
  191. if self.has_train_data:
  192. if self.has_eval_data:
  193. validate_data = model_data.get("eval_data")
  194. elif self.has_validate_data:
  195. validate_data = model_data.get("validate_data")
  196. test_data = None
  197. if self.has_test_data:
  198. test_data = model_data.get("test_data")
  199. self.has_test_data = True
  200. elif self.has_eval_data and not self.has_train_data:
  201. test_data = model_data.get("eval_data")
  202. self.has_test_data = True
  203. if validate_data or (self.has_train_data and self.has_eval_data):
  204. self.has_validate_data = True
  205. if self.has_train_data and is_table(train_data):
  206. self.input_data_count = train_data.count()
  207. elif self.has_normal_input_data:
  208. for data_key, data_table in data.items():
  209. if is_table(data_table):
  210. self.input_data_count = data_table.count()
  211. if self.has_validate_data and is_table(validate_data):
  212. self.input_eval_data_count = validate_data.count()
  213. self._abnormal_dsl_config_detect()
  214. LOGGER.debug(
  215. f"train_data: {train_data}, validate_data: {validate_data}, "
  216. f"test_data: {test_data}, data: {data}"
  217. )
  218. return train_data, validate_data, test_data, data
  219. def warm_start_process(self, running_funcs, model, train_data, validate_data, schema=None):
  220. if schema is None:
  221. for d in [train_data, validate_data]:
  222. if d is not None:
  223. schema = d.schema
  224. break
  225. running_funcs = self._train_process(running_funcs, model, train_data, validate_data,
  226. test_data=None, schema=schema)
  227. return running_funcs
  228. def _train_process(self, running_funcs, model, train_data, validate_data, test_data, schema):
  229. if self.has_train_data and self.has_validate_data:
  230. running_funcs.add_func(model.set_flowid, ['fit'])
  231. running_funcs.add_func(model.fit, [train_data, validate_data])
  232. running_funcs.add_func(model.set_flowid, ['validate'])
  233. running_funcs.add_func(model.predict, [train_data], save_result=True)
  234. running_funcs.add_func(model.set_flowid, ['predict'])
  235. running_funcs.add_func(model.predict, [validate_data], save_result=True)
  236. running_funcs.add_func(self.union_data, ["train", "validate"], use_previews=True, save_result=True)
  237. running_funcs.add_func(model.set_predict_data_schema, [schema],
  238. use_previews=True, save_result=True)
  239. elif self.has_train_data:
  240. running_funcs.add_func(model.set_flowid, ['fit'])
  241. running_funcs.add_func(model.fit, [train_data])
  242. running_funcs.add_func(model.set_flowid, ['validate'])
  243. running_funcs.add_func(model.predict, [train_data], save_result=True)
  244. running_funcs.add_func(self.union_data, ["train"], use_previews=True, save_result=True)
  245. running_funcs.add_func(model.set_predict_data_schema, [schema],
  246. use_previews=True, save_result=True)
  247. elif self.has_test_data:
  248. running_funcs.add_func(model.set_flowid, ['predict'])
  249. running_funcs.add_func(model.predict, [test_data], save_result=True)
  250. running_funcs.add_func(self.union_data, ["predict"], use_previews=True, save_result=True)
  251. running_funcs.add_func(model.set_predict_data_schema, [schema],
  252. use_previews=True, save_result=True)
  253. return running_funcs
  254. def extract_running_rules(self, datasets, models, cpn):
  255. # train_data, eval_data, data = self.extract_input_data(args)
  256. train_data, validate_data, test_data, data = self.extract_input_data(
  257. datasets, cpn
  258. )
  259. running_funcs = RunningFuncs()
  260. schema = None
  261. for d in [train_data, validate_data, test_data]:
  262. if d is not None:
  263. schema = d.schema
  264. break
  265. if not self.need_run:
  266. running_funcs.add_func(cpn.pass_data, [data], save_result=True)
  267. return running_funcs
  268. if self.need_cv:
  269. running_funcs.add_func(cpn.cross_validation, [train_data], save_result=True)
  270. return running_funcs
  271. if self.need_stepwise:
  272. running_funcs.add_func(cpn.stepwise, [train_data], save_result=True)
  273. running_funcs.add_func(self.union_data, ["train"], use_previews=True, save_result=True)
  274. running_funcs.add_func(cpn.set_predict_data_schema, [schema],
  275. use_previews=True, save_result=True)
  276. return running_funcs
  277. if self.has_model or self.has_isometric_model:
  278. running_funcs.add_func(cpn.load_model, [models])
  279. if self.is_warm_start:
  280. return self.warm_start_process(running_funcs, cpn, train_data, validate_data, schema)
  281. running_funcs = self._train_process(running_funcs, cpn, train_data, validate_data, test_data, schema)
  282. if self.has_normal_input_data and not self.has_model:
  283. running_funcs.add_func(cpn.extract_data, [data], save_result=True)
  284. running_funcs.add_func(cpn.set_flowid, ['fit'])
  285. running_funcs.add_func(cpn.fit, [], use_previews=True, save_result=True)
  286. if self.has_normal_input_data and self.has_model:
  287. running_funcs.add_func(cpn.extract_data, [data], save_result=True)
  288. running_funcs.add_func(cpn.set_flowid, ['transform'])
  289. running_funcs.add_func(cpn.transform, [], use_previews=True, save_result=True)
  290. return running_funcs
  291. @staticmethod
  292. def union_data(previews_data, name_list):
  293. if len(previews_data) == 0:
  294. return None
  295. if any([x is None for x in previews_data]):
  296. return None
  297. assert len(previews_data) == len(name_list)
  298. def _append_name(value, name):
  299. inst = copy.deepcopy(value)
  300. if isinstance(inst.features, list):
  301. inst.features.append(name)
  302. else:
  303. inst.features = np.append(inst.features, name)
  304. return inst
  305. result_data = None
  306. for data, name in zip(previews_data, name_list):
  307. # LOGGER.debug("before mapValues, one data: {}".format(data.first()))
  308. f = functools.partial(_append_name, name=name)
  309. data = data.mapValues(f)
  310. # LOGGER.debug("after mapValues, one data: {}".format(data.first()))
  311. if result_data is None:
  312. result_data = data
  313. else:
  314. LOGGER.debug(
  315. f"Before union, t1 count: {result_data.count()}, t2 count: {data.count()}"
  316. )
  317. result_data = result_data.union(data)
  318. LOGGER.debug(f"After union, result count: {result_data.count()}")
  319. # LOGGER.debug("before out loop, one data: {}".format(result_data.first()))
  320. return result_data
  321. def set_union_func(self, func):
  322. self.union_data = func