123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
class PredictDataCache(object):
    """Cache of accumulated per-round predictions, keyed by dataset.

    Each dataset key maps to a ``DataNode`` that stores the predictions added
    for successive boosting rounds of that dataset.
    """

    def __init__(self):
        # dataset_key -> DataNode
        self._data_map = {}

    def predict_data_at(self, dataset_key, round):
        """Return the cached predictions of ``dataset_key`` for ``round``.

        Returns None when nothing has been cached for ``dataset_key``;
        otherwise the result of that node's ``data_at(round)``.
        """
        node = self._data_map.get(dataset_key)
        return None if node is None else node.data_at(round)

    def predict_data_last_round(self, dataset_key):
        """Return the most recent round cached for ``dataset_key``.

        Returns 0 when nothing has been cached for ``dataset_key``.
        NOTE: that is indistinguishable from a dataset whose last cached
        round is 0.
        """
        node = self._data_map.get(dataset_key)
        return 0 if node is None else node.get_last_round()

    @staticmethod
    def get_data_key(data):
        """Return the cache key for ``data``, which is its ``id()``.

        NOTE: ``id`` values may be reused once an object is garbage collected,
        and this cache keeps no reference to ``data`` itself, so a later
        object could map onto a stale entry.
        """
        return id(data)

    def add_data(self, dataset_key, f, cur_boosting_round):
        """Record predictions ``f`` for ``cur_boosting_round`` under ``dataset_key``.

        A fresh ``DataNode`` is created the first time a key is seen.
        """
        node = self._data_map.get(dataset_key)
        if node is None:
            node = DataNode()
            self._data_map[dataset_key] = node
        node.add_data(f, cur_boosting_round)
class DataNode(object):
    """Accumulated per-round predictions for a single dataset.

    ``self._f`` is a table whose value for each key is a list holding one
    prediction per ``add_data`` call, in call order. ``self._round_idx_map``
    maps a boosting-round number to its position in those lists.
    """

    def __init__(self):
        self._boost_round = None  # round passed to the latest add_data; None until then
        self._f = None  # table of per-key prediction lists; None until first add_data
        self._round_idx_map = {}  # boosting round -> index into each value list
        self._idx = 0  # list index written by the latest add_data

    def get_last_round(self):
        """Return the round passed to the latest ``add_data``, or None if none yet."""
        return self._boost_round

    def data_at(self, round):
        """Return a table holding, per key, the prediction recorded for ``round``.

        Returns None when ``round`` was never added.
        """
        idx = self._round_idx_map.get(round)
        if idx is None:
            return None
        # Close over the resolved index only. Referencing ``self`` here would
        # drag the whole node (including the table in ``self._f``) into the
        # function given to ``mapValues`` -- a problem if the table serializes
        # it for remote execution -- and would re-read the map on every row.
        return self._f.mapValues(lambda f_list: f_list[idx])

    def add_data(self, f, cur_round_num):
        """Append predictions ``f`` as the entry for round ``cur_round_num``.

        The first call wraps each value of ``f`` in a one-element list; later
        calls join ``f`` onto the existing table and append each new score to
        that key's list.
        """
        # Must be evaluated before _boost_round is overwritten below.
        is_first = self._boost_round is None
        self._boost_round = cur_round_num
        if is_first:
            self._idx = 0
            self._f = f.mapValues(lambda pred: [pred])
        else:
            self._idx += 1
            # NOTE(review): keys missing from either table presumably drop out,
            # depending on the table's join semantics -- verify.
            self._f = self._f.join(f, lambda pre_scores, score: pre_scores + [score])
        self._round_idx_map[self._boost_round] = self._idx
|