predict_cache.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. class PredictDataCache(object):
  19. def __init__(self):
  20. self._data_map = {}
  21. def predict_data_at(self, dataset_key, round):
  22. if dataset_key not in self._data_map:
  23. return None
  24. return self._data_map[dataset_key].data_at(round)
  25. def predict_data_last_round(self, dataset_key):
  26. if dataset_key not in self._data_map:
  27. return 0 # start from 0
  28. return self._data_map[dataset_key].get_last_round()
  29. @staticmethod
  30. def get_data_key(data):
  31. return id(data)
  32. def add_data(self, dataset_key, f, cur_boosting_round):
  33. if dataset_key not in self._data_map:
  34. self._data_map[dataset_key] = DataNode()
  35. self._data_map[dataset_key].add_data(f, cur_boosting_round)
  36. class DataNode(object):
  37. def __init__(self):
  38. self._boost_round = None
  39. self._f = None
  40. self._round_idx_map = {}
  41. self._idx = 0
  42. def get_last_round(self):
  43. return self._boost_round
  44. def data_at(self, round):
  45. if round not in self._round_idx_map:
  46. return None
  47. return self._f.mapValues(lambda f_list: f_list[self._round_idx_map[round]])
  48. def add_data(self, f, cur_round_num):
  49. if self._boost_round is None:
  50. self._boost_round = cur_round_num
  51. self._idx = 0
  52. self._f = f.mapValues(lambda pred: [pred])
  53. else:
  54. self._boost_round = cur_round_num
  55. self._idx += 1
  56. self._f = self._f.join(f, lambda pre_scores, score: pre_scores + [score])
  57. self._round_idx_map[self._boost_round] = self._idx