kmeans_model_base.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from fate_arch.common import log
  18. from federatedml.model_base import ModelBase
  19. from federatedml.param.hetero_kmeans_param import KmeansParam
  20. from federatedml.protobuf.generated import hetero_kmeans_meta_pb2, hetero_kmeans_param_pb2
  21. from federatedml.transfer_variable.transfer_class.hetero_kmeans_transfer_variable import HeteroKmeansTransferVariable
  22. from federatedml.util import abnormal_detection
  23. from federatedml.feature.instance import Instance
  24. from federatedml.util import consts
  25. import functools
# Module-level logger used by the debug statements in this file.
LOGGER = log.getLogger()
  27. class BaseKmeansModel(ModelBase):
  28. def __init__(self):
  29. super(BaseKmeansModel, self).__init__()
  30. self.model_param = KmeansParam()
  31. self.n_iter_ = 0
  32. self.k = 0
  33. self.max_iter = 0
  34. self.tol = 0
  35. self.random_stat = None
  36. self.iter = iter
  37. self.centroid_list = None
  38. self.cluster_result = None
  39. self.transfer_variable = HeteroKmeansTransferVariable()
  40. self.model_name = 'toSet'
  41. self.model_param_name = 'HeteroKmeansParam'
  42. self.model_meta_name = 'HeteroKmeansMeta'
  43. self.header = None
  44. self.reset_union()
  45. self.is_converged = False
  46. self.cluster_detail = None
  47. self.cluster_count = None
  48. self.aggregator = None
  49. def _init_model(self, params):
  50. self.model_param = params
  51. self.k = params.k
  52. self.max_iter = params.max_iter
  53. self.tol = params.tol
  54. self.random_stat = params.random_stat
  55. # self.aggregator.register_aggregator(self.transfer_variable)
  56. def get_header(self, data_instances):
  57. if self.header is not None:
  58. return self.header
  59. return data_instances.schema.get("header")
  60. def _get_meta(self):
  61. meta_protobuf_obj = hetero_kmeans_meta_pb2.KmeansModelMeta(k=self.model_param.k,
  62. tol=self.model_param.tol,
  63. max_iter=self.max_iter)
  64. return meta_protobuf_obj
  65. def _get_param(self):
  66. header = self.header
  67. LOGGER.debug("In get_param, header: {}".format(header))
  68. if header is None:
  69. param_protobuf_obj = hetero_kmeans_param_pb2.KmeansModelParam()
  70. return param_protobuf_obj
  71. cluster_detail = [hetero_kmeans_param_pb2.Clusterdetail(cluster=cluster) for cluster in self.cluster_count]
  72. centroid_detail = [hetero_kmeans_param_pb2.Centroiddetail(centroid=centroid) for centroid in self.centroid_list]
  73. param_protobuf_obj = hetero_kmeans_param_pb2.KmeansModelParam(count_of_clusters=self.k,
  74. max_interation=self.n_iter_,
  75. converged=self.is_converged,
  76. cluster_detail=cluster_detail,
  77. centroid_detail=centroid_detail,
  78. header=self.header)
  79. return param_protobuf_obj
  80. def export_model(self):
  81. meta_obj = self._get_meta()
  82. param_obj = self._get_param()
  83. result = {
  84. self.model_meta_name: meta_obj,
  85. self.model_param_name: param_obj
  86. }
  87. return result
  88. def count(self, iterator):
  89. count_result = dict()
  90. for k, v in iterator:
  91. if v not in count_result:
  92. count_result[v] = 1
  93. else:
  94. count_result[v] += 1
  95. return count_result
  96. @staticmethod
  97. def sum_dict(d1, d2):
  98. temp = dict()
  99. for key in d1.keys() | d2.keys():
  100. temp[key] = sum([d.get(key, 0) for d in (d1, d2)])
  101. return temp
  102. def _abnormal_detection(self, data_instances):
  103. """
  104. Make sure input data_instances is valid.
  105. """
  106. abnormal_detection.empty_table_detection(data_instances)
  107. abnormal_detection.empty_feature_detection(data_instances)
  108. def load_model(self, model_dict):
  109. param_obj = list(model_dict.get('model').values())[0].get(self.model_param_name)
  110. meta_obj = list(model_dict.get('model').values())[0].get(self.model_meta_name)
  111. self.k = meta_obj.k
  112. self.centroid_list = list(param_obj.centroid_detail)
  113. for idx, c in enumerate(self.centroid_list):
  114. self.centroid_list[idx] = list(c.centroid)
  115. self.cluster_count = list(param_obj.cluster_detail)
  116. for idx, c in enumerate(self.cluster_count):
  117. self.cluster_count[idx] = list(c.cluster)
  118. # self.header = list(result_obj.header)
  119. # if self.header is None:
  120. # return
  121. def reset_union(self):
  122. def _add_name(inst, name):
  123. return Instance(features=inst.features + [name], inst_id=inst.inst_id)
  124. def kmeans_union(previews_data, name_list):
  125. if len(previews_data) == 0:
  126. return None
  127. if any([x is None for x in previews_data]):
  128. return None
  129. # assert len(previews_data) == len(name_list)
  130. if self.role == consts.ARBITER:
  131. data_outputs = []
  132. for data_output, name in zip(previews_data, name_list):
  133. f = functools.partial(_add_name, name=name)
  134. data_output1 = data_output[0].mapValues(f)
  135. data_output2 = data_output[1].mapValues(f)
  136. data_outputs.append([data_output1, data_output2])
  137. else:
  138. data_output1 = sub_union(previews_data, name_list)
  139. data_outputs = [data_output1, None]
  140. return data_outputs
  141. def sub_union(data_output, name_list):
  142. result_data = None
  143. for data, name in zip(data_output, name_list):
  144. # LOGGER.debug("before mapValues, one data: {}".format(data.first()))
  145. f = functools.partial(_add_name, name=name)
  146. data = data.mapValues(f)
  147. # LOGGER.debug("after mapValues, one data: {}".format(data.first()))
  148. if result_data is None:
  149. result_data = data
  150. else:
  151. LOGGER.debug(f"Before union, t1 count: {result_data.count()}, t2 count: {data.count()}")
  152. result_data = result_data.union(data)
  153. LOGGER.debug(f"After union, result count: {result_data.count()}")
  154. # LOGGER.debug("before out loop, one data: {}".format(result_data.first()))
  155. return result_data
  156. self.component_properties.set_union_func(kmeans_union)
  157. def set_predict_data_schema(self, predict_datas, schemas):
  158. if predict_datas is None:
  159. return None, None
  160. predict_data = predict_datas[0][0]
  161. schema = schemas[0]
  162. if self.role == consts.ARBITER:
  163. data_output1 = predict_data[0]
  164. data_output2 = predict_data[1]
  165. if data_output1 is not None:
  166. data_output1.schema = {
  167. "header": ["cluster_sample_count", "cluster_inner_dist", "inter_cluster_dist", "type"],
  168. "sid": "cluster_index",
  169. "content_type": "cluster_result"
  170. }
  171. if data_output2 is not None:
  172. data_output2.schema = {"header": ["predicted_cluster_index", "distance", "type"],
  173. "sid": "id",
  174. "content_type": "cluster_result"}
  175. predict_datas = [data_output1, data_output2]
  176. else:
  177. data_output = predict_data
  178. if data_output is not None:
  179. data_output.schema = {"header": ["label", "predicted_label", "type"],
  180. "sid": schema.get('sid'),
  181. "content_type": "cluster_result"}
  182. if "match_id_name" in schema:
  183. data_output.schema["match_id_name"] = schema["match_id_name"]
  184. predict_datas = [data_output, None]
  185. return predict_datas