sampler.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import functools
  17. import math
  18. import random
  19. from sklearn.utils import resample
  20. from fate_arch.session import computing_session as session
  21. from federatedml.model_base import Metric
  22. from federatedml.model_base import MetricMeta
  23. from federatedml.model_base import ModelBase
  24. from federatedml.param.sample_param import SampleParam
  25. from federatedml.transfer_variable.transfer_class.sample_transfer_variable import SampleTransferVariable
  26. from federatedml.util import LOGGER
  27. from federatedml.util import consts
  28. from federatedml.util.schema_check import assert_schema_consistent
  29. from fate_arch.common.base_utils import fate_uuid
  30. class RandomSampler(object):
  31. """
  32. Random Sampling Method
  33. Parameters
  34. ----------
  35. fraction : None or float, sampling ratio, default: 0.1
  36. random_state: int, RandomState instance or None, optional, default: None
  37. method: str, supported "upsample", "downsample" only in this version, default: "downsample"
  38. """
  39. def __init__(self, fraction=0.1, random_state=None, method="downsample"):
  40. self.fraction = fraction
  41. self.random_state = random_state
  42. self.method = method
  43. self.tracker = None
  44. self._summary_buf = {}
  45. def set_tracker(self, tracker):
  46. self.tracker = tracker
  47. def sample(self, data_inst, sample_ids=None):
  48. """
  49. Interface to call random sample method
  50. Parameters
  51. ----------
  52. data_inst : Table
  53. The input data
  54. sample_ids : None or list
  55. if None, will sample data from the class instance's parameters,
  56. otherwise, it will be sample transform process, which means use the samples_ids to generate data
  57. Returns
  58. -------
  59. new_data_inst: Table
  60. the output sample data, same format with input
  61. sample_ids: list, return only if sample_ids is None
  62. """
  63. if sample_ids is None:
  64. new_data_inst, sample_ids = self.__sample(data_inst)
  65. return new_data_inst, sample_ids
  66. else:
  67. new_data_inst = self.__sample(data_inst, sample_ids)
  68. return new_data_inst
  69. def __sample(self, data_inst, sample_ids=None):
  70. """
  71. Random sample method, a line's occur probability is decide by fraction
  72. support down sample and up sample
  73. if use down sample: should give a float ratio between [0, 1]
  74. otherwise: should give a float ratio larger than 1.0
  75. Parameters
  76. ----------
  77. data_inst : Table
  78. The input data
  79. sample_ids : None or list
  80. if None, will sample data from the class instance's parameters,
  81. otherwise, it will be sample transform process, which means use the samples_ids to generate data
  82. Returns
  83. -------
  84. new_data_inst: Table
  85. the output sample data, same format with input
  86. sample_ids: list, return only if sample_ids is None
  87. """
  88. LOGGER.info("start to run random sampling")
  89. return_sample_ids = False
  90. if self.method == "downsample":
  91. if sample_ids is None:
  92. return_sample_ids = True
  93. idset = [key for key, value in data_inst.mapValues(lambda val: None).collect()]
  94. if self.fraction < 0 or self.fraction > 1:
  95. raise ValueError("sapmle fractions should be a numeric number between 0 and 1inclusive")
  96. sample_num = max(1, int(self.fraction * len(idset)))
  97. sample_ids = resample(idset,
  98. replace=False,
  99. n_samples=sample_num,
  100. random_state=self.random_state)
  101. sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))),
  102. include_key=True,
  103. partition=data_inst.partitions)
  104. new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)
  105. callback(self.tracker, "random", [Metric("count", new_data_inst.count())], summary_dict=self._summary_buf)
  106. if return_sample_ids:
  107. return new_data_inst, sample_ids
  108. else:
  109. return new_data_inst
  110. elif self.method == "upsample":
  111. data_set = list(data_inst.collect())
  112. idset = [key for (key, value) in data_set]
  113. id_maps = dict(zip(idset, range(len(idset))))
  114. if sample_ids is None:
  115. return_sample_ids = True
  116. if self.fraction <= 0:
  117. raise ValueError("sample fractions should be a numeric number large than 0")
  118. sample_num = int(self.fraction * len(idset))
  119. sample_ids = resample(idset,
  120. replace=True,
  121. n_samples=sample_num,
  122. random_state=self.random_state)
  123. new_data = []
  124. for i in range(len(sample_ids)):
  125. index = id_maps[sample_ids[i]]
  126. new_data.append((i, data_set[index][1]))
  127. new_data_inst = session.parallelize(new_data,
  128. include_key=True,
  129. partition=data_inst.partitions)
  130. callback(self.tracker, "random", [Metric("count", new_data_inst.count())], summary_dict=self._summary_buf)
  131. if return_sample_ids:
  132. return new_data_inst, sample_ids
  133. else:
  134. return new_data_inst
  135. else:
  136. raise ValueError("random sampler not support method {} yet".format(self.method))
  137. def get_summary(self):
  138. return self._summary_buf
  139. class StratifiedSampler(object):
  140. """
  141. Stratified Sampling Method
  142. Parameters
  143. ----------
  144. fractions : None or list of (category, sample ratio) tuple,
  145. sampling ratios of each category, default: None
  146. e.g.
  147. [(0, 0.5), (1, 0.1]) in down sample, [(1, 1.5), (0, 1.8)], where 0\1 are the the occurred category.
  148. random_state: int, RandomState instance or None, optional, default: None
  149. method: str, supported "upsample", "downsample", default: "downsample"
  150. """
  151. def __init__(self, fractions=None, random_state=None, method="downsample"):
  152. self.fractions = fractions
  153. self.label_mapping = {}
  154. self.labels = []
  155. if fractions:
  156. for (label, frac) in fractions:
  157. self.label_mapping[label] = len(self.labels)
  158. self.labels.append(label)
  159. # self.label_mapping = [label for (label, frac) in fractions]
  160. self.random_state = random_state
  161. self.method = method
  162. self.tracker = None
  163. self._summary_buf = {}
  164. def set_tracker(self, tracker):
  165. self.tracker = tracker
  166. def sample(self, data_inst, sample_ids=None):
  167. """
  168. Interface to call stratified sample method
  169. Parameters
  170. ----------
  171. data_inst : Table
  172. The input data
  173. sample_ids : None or list
  174. if None, will sample data from the class instance's key by sample parameters,
  175. otherwise, it will be sample transform process, which means use the samples_ids to generate data
  176. Returns
  177. -------
  178. new_data_inst: Table
  179. the output sample data, same format with input
  180. sample_ids: list, return only if sample_ids is None
  181. """
  182. if sample_ids is None:
  183. new_data_inst, sample_ids = self.__sample(data_inst)
  184. return new_data_inst, sample_ids
  185. else:
  186. new_data_inst = self.__sample(data_inst, sample_ids)
  187. return new_data_inst
  188. def __sample(self, data_inst, sample_ids=None):
  189. """
  190. Stratified sample method, a line's occur probability is decide by fractions
  191. Input should be Table, every line should be an instance object with label
  192. To use this method, a list of ratio should be give, and the list length
  193. equals to the number of distinct labels
  194. support down sample and up sample
  195. if use down sample: should give a list of (category, ratio), where ratio is between [0, 1]
  196. otherwise: should give a list (category, ratio), where the float ratio should no less than 1.0
  197. Parameters
  198. ----------
  199. data_inst : Table
  200. The input data
  201. sample_ids : None or list
  202. if None, will sample data from the class instance's parameters,
  203. otherwise, it will be sample transform process, which means use the samples_ids the generate data
  204. Returns
  205. -------
  206. new_data_inst: Table
  207. the output sample data, sample format with input
  208. sample_ids: list, return only if sample_ids is None
  209. """
  210. LOGGER.info("start to run stratified sampling")
  211. return_sample_ids = False
  212. if self.method == "downsample":
  213. if sample_ids is None:
  214. idset = [[] for i in range(len(self.fractions))]
  215. for label, fraction in self.fractions:
  216. if fraction < 0 or fraction > 1:
  217. raise ValueError("sapmle fractions should be a numeric number between 0 and 1inclusive")
  218. return_sample_ids = True
  219. for key, inst in data_inst.collect():
  220. label = inst.label
  221. if label not in self.label_mapping:
  222. raise ValueError("label not specify sample rate! check it please")
  223. idset[self.label_mapping[label]].append(key)
  224. sample_ids = []
  225. callback_sample_metrics = []
  226. callback_original_metrics = []
  227. for i in range(len(idset)):
  228. label_name = self.labels[i]
  229. callback_original_metrics.append(Metric(label_name, len(idset[i])))
  230. if idset[i]:
  231. sample_num = max(1, int(self.fractions[i][1] * len(idset[i])))
  232. _sample_ids = resample(idset[i],
  233. replace=False,
  234. n_samples=sample_num,
  235. random_state=self.random_state)
  236. sample_ids.extend(_sample_ids)
  237. callback_sample_metrics.append(Metric(label_name, len(_sample_ids)))
  238. else:
  239. callback_sample_metrics.append(Metric(label_name, 0))
  240. random.shuffle(sample_ids)
  241. callback(
  242. self.tracker,
  243. "stratified",
  244. callback_sample_metrics,
  245. callback_original_metrics,
  246. self._summary_buf)
  247. sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))),
  248. include_key=True,
  249. partition=data_inst.partitions)
  250. new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)
  251. if return_sample_ids:
  252. return new_data_inst, sample_ids
  253. else:
  254. return new_data_inst
  255. elif self.method == "upsample":
  256. data_set = list(data_inst.collect())
  257. ids = [key for (key, inst) in data_set]
  258. id_maps = dict(zip(ids, range(len(ids))))
  259. return_sample_ids = False
  260. if sample_ids is None:
  261. idset = [[] for i in range(len(self.fractions))]
  262. for label, fraction in self.fractions:
  263. if fraction <= 0:
  264. raise ValueError("sapmle fractions should be a numeric number greater than 0")
  265. for key, inst in data_set:
  266. label = inst.label
  267. if label not in self.label_mapping:
  268. raise ValueError("label not specify sample rate! check it please")
  269. idset[self.label_mapping[label]].append(key)
  270. return_sample_ids = True
  271. sample_ids = []
  272. callback_sample_metrics = []
  273. callback_original_metrics = []
  274. for i in range(len(idset)):
  275. label_name = self.labels[i]
  276. callback_original_metrics.append(Metric(label_name, len(idset[i])))
  277. if idset[i]:
  278. sample_num = max(1, int(self.fractions[i][1] * len(idset[i])))
  279. _sample_ids = resample(idset[i],
  280. replace=True,
  281. n_samples=sample_num,
  282. random_state=self.random_state)
  283. sample_ids.extend(_sample_ids)
  284. callback_sample_metrics.append(Metric(label_name, len(_sample_ids)))
  285. else:
  286. callback_sample_metrics.append(Metric(label_name, 0))
  287. random.shuffle(sample_ids)
  288. callback(
  289. self.tracker,
  290. "stratified",
  291. callback_sample_metrics,
  292. callback_original_metrics,
  293. self._summary_buf)
  294. new_data = []
  295. for i in range(len(sample_ids)):
  296. index = id_maps[sample_ids[i]]
  297. new_data.append((i, data_set[index][1]))
  298. new_data_inst = session.parallelize(new_data,
  299. include_key=True,
  300. partition=data_inst.partitions)
  301. if return_sample_ids:
  302. return new_data_inst, sample_ids
  303. else:
  304. return new_data_inst
  305. else:
  306. raise ValueError("Stratified sampler not support method {} yet".format(self.method))
  307. def get_summary(self):
  308. return self._summary_buf
  309. class ExactSampler(object):
  310. """
  311. Exact Sampling Method
  312. Parameters
  313. ----------
  314. """
  315. def __init__(self):
  316. self.tracker = None
  317. self._summary_buf = {}
  318. def set_tracker(self, tracker):
  319. self.tracker = tracker
  320. def get_sample_ids(self, data_inst):
  321. original_sample_count = data_inst.count()
  322. non_zero_data_inst = data_inst.filter(lambda k, v: v.weight > consts.FLOAT_ZERO)
  323. non_zero_sample_count = data_inst.count()
  324. if original_sample_count != non_zero_sample_count:
  325. sample_diff = original_sample_count - non_zero_sample_count
  326. LOGGER.warning(f"{sample_diff} zero-weighted sample encountered, will be discarded in final result.")
  327. def __generate_new_ids(v):
  328. if v.inst_id is None:
  329. raise ValueError(f"To sample with `exact_by_weight` mode, instances must have match id."
  330. f"Please check.")
  331. new_key_num = math.ceil(v.weight)
  332. new_sample_id_list = [fate_uuid() for _ in range(new_key_num)]
  333. return new_sample_id_list
  334. sample_ids = non_zero_data_inst.mapValues(lambda v: __generate_new_ids(v))
  335. return sample_ids
  336. def sample(self, data_inst, sample_ids=None):
  337. """
  338. Interface to call stratified sample method
  339. Parameters
  340. ----------
  341. data_inst : Table
  342. The input data
  343. sample_ids : Table
  344. use the samples_ids to generate data
  345. Returns
  346. -------
  347. new_data_inst: Table
  348. the output sample data, same format with input
  349. """
  350. LOGGER.info("start to generate exact sampling result")
  351. new_data_inst = self.__sample(data_inst, sample_ids)
  352. return new_data_inst
  353. def __sample(self, data_inst, sample_ids):
  354. """
  355. Exact sample method, duplicate samples by corresponding weight:
  356. if weight <= 0, discard sample; if round(weight) == 1, keep one,
  357. else duplicate round(weight) copies of sample
  358. Parameters
  359. ----------
  360. data_inst : Table
  361. The input data
  362. sample_ids : Table
  363. use the samples_ids the generate data
  364. Returns
  365. -------
  366. new_data_inst: Table
  367. the output sample data, sample format with input
  368. """
  369. sample_ids_map = data_inst.join(sample_ids, lambda v, ids: (v, ids))
  370. def __sample_new_id(k, v_id_map):
  371. v, id_map = v_id_map
  372. return [(new_id, v) for new_id in id_map]
  373. new_data_inst = sample_ids_map.flatMap(functools.partial(__sample_new_id))
  374. data_count = new_data_inst.count()
  375. if data_count is None:
  376. data_count = 0
  377. LOGGER.warning(f"All data instances discarded. Please check weight.")
  378. callback(self.tracker, "exact_by_weight", [Metric("count", data_count)], summary_dict=self._summary_buf)
  379. return new_data_inst
  380. def get_summary(self):
  381. return self._summary_buf
  382. class Sampler(ModelBase):
  383. """
  384. Sampling Object
  385. Parameters
  386. ----------
  387. sample_param : object, self-define sample parameters,
  388. define in federatedml.param.sample_param
  389. """
  390. def __init__(self):
  391. super(Sampler, self).__init__()
  392. self.task_type = None
  393. # self.task_role = None
  394. self.flowid = 0
  395. self.model_param = SampleParam()
  396. def _init_model(self, sample_param):
  397. if sample_param.mode == "random":
  398. self.sampler = RandomSampler(sample_param.fractions,
  399. sample_param.random_state,
  400. sample_param.method)
  401. self.sampler.set_tracker(self.tracker)
  402. elif sample_param.mode == "stratified":
  403. self.sampler = StratifiedSampler(sample_param.fractions,
  404. sample_param.random_state,
  405. sample_param.method)
  406. self.sampler.set_tracker(self.tracker)
  407. elif sample_param.mode == "exact_by_weight":
  408. self.sampler = ExactSampler()
  409. self.sampler.set_tracker(self.tracker)
  410. else:
  411. raise ValueError("{} sampler not support yet".format(sample_param.mde))
  412. self.task_type = sample_param.task_type
  413. def _init_role(self, component_parameters):
  414. self.task_role = component_parameters["local"]["role"]
  415. def sample(self, data_inst, sample_ids=None):
  416. """
  417. Entry to use sample method
  418. Parameters
  419. ----------
  420. data_inst : Table
  421. The input data
  422. sample_ids : None or list
  423. if None, will sample data from the class instance's parameters,
  424. otherwise, it will be sample transform process, which means use the samples_ids the generate data
  425. Returns
  426. -------
  427. sample_data: Table
  428. the output sample data, same format with input
  429. """
  430. ori_schema = data_inst.schema
  431. sample_data = self.sampler.sample(data_inst, sample_ids)
  432. self.set_summary(self.sampler.get_summary())
  433. try:
  434. if len(sample_data) == 2:
  435. sample_data[0].schema = ori_schema
  436. except BaseException:
  437. sample_data.schema = ori_schema
  438. return sample_data
  439. def set_flowid(self, flowid="samole"):
  440. self.flowid = flowid
  441. def sync_sample_ids(self, sample_ids):
  442. transfer_inst = SampleTransferVariable()
  443. transfer_inst.sample_ids.remote(sample_ids,
  444. role="host",
  445. suffix=(self.flowid,))
  446. def recv_sample_ids(self):
  447. transfer_inst = SampleTransferVariable()
  448. sample_ids = transfer_inst.sample_ids.get(idx=0,
  449. suffix=(self.flowid,))
  450. return sample_ids
  451. def run_sample(self, data_inst, task_type, task_role):
  452. """
  453. Sample running entry
  454. Parameters
  455. ----------
  456. data_inst : Table
  457. The input data
  458. task_type : "homo" or "hetero"
  459. if task_type is "homo", it will sample standalone
  460. if task_type is "heterl": then sampling will be done in one side, after that
  461. the side sync the sample ids to another side to generated the same sample result
  462. task_role: "guest" or "host":
  463. only consider this parameter when task_type is "hetero"
  464. if task_role is "guest", it will firstly sample ids, and sync it to "host"
  465. to generate data instances with sample ids
  466. if task_role is "host": it will firstly get the sample ids result of "guest",
  467. then generate sample data by the receiving ids
  468. Returns
  469. -------
  470. sample_data_inst: Table
  471. the output sample data, same format with input
  472. """
  473. LOGGER.info("begin to run sampling process")
  474. if task_type not in [consts.HOMO, consts.HETERO]:
  475. raise ValueError("{} task type not support yet".format(task_type))
  476. if task_type == consts.HOMO:
  477. return self.sample(data_inst)[0]
  478. elif task_type == consts.HETERO:
  479. if task_role == consts.GUEST:
  480. if self.model_param.mode == "exact_by_weight":
  481. LOGGER.info("start to run exact sampling")
  482. sample_ids = self.sampler.get_sample_ids(data_inst)
  483. self.sync_sample_ids(sample_ids)
  484. sample_data_inst = self.sample(data_inst, sample_ids)
  485. else:
  486. sample_data_inst, sample_ids = self.sample(data_inst)
  487. self.sync_sample_ids(sample_ids)
  488. elif task_role == consts.HOST:
  489. sample_ids = self.recv_sample_ids()
  490. sample_data_inst = self.sample(data_inst, sample_ids)
  491. else:
  492. raise ValueError("{} role not support yet".format(task_role))
  493. return sample_data_inst
  494. @assert_schema_consistent
  495. def fit(self, data_inst):
  496. return self.run_sample(data_inst, self.task_type, self.role)
  497. def transform(self, data_inst):
  498. return self.run_sample(data_inst, self.task_type, self.role)
  499. def check_consistency(self):
  500. pass
  501. def save_data(self):
  502. return self.data_output
  503. def callback(tracker, method, callback_metrics, other_metrics=None, summary_dict=None):
  504. LOGGER.debug("callback: method is {}".format(method))
  505. if method == "random":
  506. tracker.log_metric_data("sample_count",
  507. "random",
  508. callback_metrics)
  509. tracker.set_metric_meta("sample_count",
  510. "random",
  511. MetricMeta(name="sample_count",
  512. metric_type="SAMPLE_TEXT"))
  513. summary_dict["sample_count"] = callback_metrics[0].value
  514. elif method == "stratified":
  515. LOGGER.debug(
  516. "callback: name {}, namespace {}, metrics_data {}".format("sample_count", "stratified", callback_metrics))
  517. tracker.log_metric_data("sample_count",
  518. "stratified",
  519. callback_metrics)
  520. tracker.set_metric_meta("sample_count",
  521. "stratified",
  522. MetricMeta(name="sample_count",
  523. metric_type="SAMPLE_TABLE"))
  524. tracker.log_metric_data("original_count",
  525. "stratified",
  526. other_metrics)
  527. tracker.set_metric_meta("original_count",
  528. "stratified",
  529. MetricMeta(name="original_count",
  530. metric_type="SAMPLE_TABLE"))
  531. summary_dict["sample_count"] = {}
  532. for sample_metric in callback_metrics:
  533. summary_dict["sample_count"][sample_metric.key] = sample_metric.value
  534. summary_dict["original_count"] = {}
  535. for sample_metric in other_metrics:
  536. summary_dict["original_count"][sample_metric.key] = sample_metric.value
  537. else:
  538. LOGGER.debug(
  539. f"callback: metrics_data {callback_metrics}, summary dict: {summary_dict}")
  540. tracker.log_metric_data("sample_count",
  541. "exact_by_weight",
  542. callback_metrics)
  543. tracker.set_metric_meta("sample_count",
  544. "exact_by_weight",
  545. MetricMeta(name="sample_count",
  546. metric_type="SAMPLE_TEXT"))
  547. summary_dict["sample_count"] = callback_metrics[0].value