hetero_pearson.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import numpy as np
  17. from fate_arch.common import Party
  18. from federatedml.model_base import MetricMeta, ModelBase
  19. from federatedml.param.pearson_param import PearsonParam
  20. from federatedml.secureprotol.spdz import SPDZ
  21. from federatedml.secureprotol.spdz.tensor.fixedpoint_table import (
  22. FixedPointTensor,
  23. table_dot,
  24. )
  25. from federatedml.statistic.data_overview import get_anonymous_header, get_header
  26. from federatedml.transfer_variable.base_transfer_variable import BaseTransferVariables
  27. from federatedml.util import LOGGER
  28. class PearsonTransferVariable(BaseTransferVariables):
  29. def __init__(self, flowid=0):
  30. super().__init__(flowid)
  31. self.anonymous_host = self._create_variable(
  32. "anonymous_host", src=["host"], dst=["guest"]
  33. )
  34. self.anonymous_guest = self._create_variable(
  35. "anonymous_guest", src=["guest"], dst=["host"]
  36. )
  37. class HeteroPearson(ModelBase):
  38. def __init__(self):
  39. super().__init__()
  40. self.model_param = PearsonParam()
  41. self.transfer_variable = PearsonTransferVariable()
  42. self._summary = {}
  43. self._modelsaver = PearsonModelSaver()
  44. def fit(self, data_instance):
  45. LOGGER.info("fit start")
  46. column_names = get_header(data_instance)
  47. column_anonymous_names = get_anonymous_header(data_instance)
  48. self._modelsaver.save_local_anonymous(column_names, column_anonymous_names)
  49. parties = [
  50. Party("guest", self.component_properties.guest_partyid),
  51. Party("host", self.component_properties.host_party_idlist[0]),
  52. ]
  53. local_party = parties[0] if self.is_guest else parties[1]
  54. other_party = parties[1] if self.is_guest else parties[0]
  55. self._modelsaver.save_party(local_party)
  56. LOGGER.info("select features")
  57. names, selected_features = select_columns(
  58. data_instance,
  59. self.model_param.column_indexes,
  60. self.model_param.column_names,
  61. )
  62. LOGGER.info("standardized feature data")
  63. num_data, standardized, remainds_indexes, num_features = standardize(
  64. selected_features
  65. )
  66. self._summary["num_local_features"] = num_features
  67. # local corr
  68. LOGGER.info("calculate correlation cross local features")
  69. local_corr = table_dot(standardized, standardized) / num_data
  70. fixed_local_corr = fix_local_corr(local_corr, remainds_indexes, num_features)
  71. self._modelsaver.save_local_corr(fixed_local_corr)
  72. self._summary["local_corr"] = fixed_local_corr.tolist()
  73. shape = fixed_local_corr.shape[0]
  74. # local vif
  75. if self.model_param.calc_local_vif:
  76. LOGGER.info("calc_local_vif enabled, calculate vif for local features")
  77. local_vif = vif_from_pearson_matrix(local_corr)
  78. fixed_local_vif = fix_vif(local_vif, remainds_indexes, num_features)
  79. self._modelsaver.save_local_vif(fixed_local_vif)
  80. else:
  81. LOGGER.info("calc_local_vif disabled, skip local vif")
  82. # not cross parties
  83. if not self.model_param.cross_parties:
  84. LOGGER.info("cross_parties disabled, save model")
  85. self._modelsaver.save_party_info(shape, local_party, names)
  86. # cross parties
  87. else:
  88. LOGGER.info(
  89. "cross_parties enabled, calculating correlation with remote features"
  90. )
  91. # sync anonymous
  92. LOGGER.info("sync anonymous names")
  93. remote_anonymous_names, remote_remainds_indexes = self.sync_anonymous_names(
  94. column_anonymous_names, remainds_indexes
  95. )
  96. if self.is_guest:
  97. names = [column_names, remote_anonymous_names]
  98. remainds_indexes_tuple = (remainds_indexes, remote_remainds_indexes)
  99. else:
  100. names = [remote_anonymous_names, column_names]
  101. remainds_indexes_tuple = (remote_remainds_indexes, remainds_indexes)
  102. m1, m2 = len(names[0]), len(names[1])
  103. shapes = [m1, m2]
  104. for shape, party, name in zip(shapes, parties, names):
  105. self._modelsaver.save_party_info(shape, party, name)
  106. self._summary["num_remote_features"] = m2 if self.is_guest else m1
  107. with SPDZ(
  108. "pearson",
  109. local_party=local_party,
  110. all_parties=parties,
  111. use_mix_rand=self.model_param.use_mix_rand,
  112. ) as spdz:
  113. LOGGER.info("secret share: prepare data")
  114. if self.is_guest:
  115. x, y = (
  116. FixedPointTensor.from_source("x", standardized),
  117. FixedPointTensor.from_source("y", other_party),
  118. )
  119. else:
  120. y, x = (
  121. FixedPointTensor.from_source("y", standardized),
  122. FixedPointTensor.from_source("x", other_party),
  123. )
  124. LOGGER.info("secret share: dot")
  125. corr = spdz.dot(x, y, "corr").get() / num_data
  126. fixed_corr = fix_corr(
  127. corr, m1, m2, remainds_indexes_tuple[0], remainds_indexes_tuple[1]
  128. )
  129. self._modelsaver.save_cross_corr(fixed_corr)
  130. self._summary["corr"] = fixed_corr.tolist()
  131. self._callback()
  132. self.set_summary(self._summary)
  133. LOGGER.info("fit done")
  134. @property
  135. def is_guest(self):
  136. return self.component_properties.role == "guest"
  137. def _init_model(self, param):
  138. super()._init_model(param)
  139. self.model_param = param
  140. def export_model(self):
  141. return self._modelsaver.export()
  142. # noinspection PyTypeChecker
  143. def _callback(self):
  144. self.tracker.set_metric_meta(
  145. metric_namespace="statistic",
  146. metric_name="correlation",
  147. metric_meta=MetricMeta(name="pearson", metric_type="CORRELATION_GRAPH"),
  148. )
  149. def sync_anonymous_names(self, local_anonymous, remainds_indexes):
  150. if self.is_guest:
  151. self.transfer_variable.anonymous_guest.remote(
  152. (local_anonymous, remainds_indexes), role="host"
  153. )
  154. (
  155. remote_anonymous,
  156. remote_remainds_indexes,
  157. ) = self.transfer_variable.anonymous_host.get(role="host", idx=0)
  158. else:
  159. self.transfer_variable.anonymous_host.remote(
  160. (local_anonymous, remainds_indexes), role="guest"
  161. )
  162. (
  163. remote_anonymous,
  164. remote_remainds_indexes,
  165. ) = self.transfer_variable.anonymous_guest.get(role="guest", idx=0)
  166. return remote_anonymous, remote_remainds_indexes
  167. class PearsonModelSaver:
  168. def __init__(self) -> None:
  169. from federatedml.protobuf.generated import (
  170. pearson_model_meta_pb2,
  171. pearson_model_param_pb2,
  172. )
  173. self.meta_pb = pearson_model_meta_pb2.PearsonModelMeta()
  174. self.param_pb = pearson_model_param_pb2.PearsonModelParam()
  175. self.param_pb.model_name = "HeteroPearson"
  176. def export(self):
  177. MODEL_META_NAME = "HeteroPearsonModelMeta"
  178. MODEL_PARAM_NAME = "HeteroPearsonModelParam"
  179. return {MODEL_META_NAME: self.meta_pb, MODEL_PARAM_NAME: self.param_pb}
  180. def save_shapes(self, shapes):
  181. for shape in shapes:
  182. self.meta_pb.shapes.append(shape)
  183. def save_local_corr(self, corr):
  184. self.param_pb.shape = corr.shape[0]
  185. for v in corr.reshape(-1):
  186. self.param_pb.local_corr.append(v.tolist())
  187. def save_party_info(self, shape, party, names):
  188. self.param_pb.shapes.append(shape)
  189. self.param_pb.parties.append(f"({party.role},{party.party_id})")
  190. _names = self.param_pb.all_names.add()
  191. for name in names:
  192. _names.names.append(name)
  193. def save_local_vif(self, local_vif):
  194. for vif_value in local_vif:
  195. self.param_pb.local_vif.append(vif_value)
  196. def save_cross_corr(self, corr):
  197. for v in corr.reshape(-1):
  198. self.param_pb.corr.append(v.tolist())
  199. def save_party(self, party):
  200. self.param_pb.party = f"({party.role},{party.party_id})"
  201. def save_local_anonymous(self, names, anonymous_names):
  202. for name, anonymous_name in zip(names, anonymous_names):
  203. self.param_pb.names.append(name)
  204. anonymous = self.param_pb.anonymous_map.add()
  205. anonymous.name = name
  206. anonymous.anonymous = anonymous_name
  207. def standardize(data):
  208. """
  209. x -> (x - mu) / sigma
  210. """
  211. n = data.count()
  212. sum_x, sum_square_x = data.mapValues(lambda x: (x, x ** 2)).reduce(
  213. lambda pair1, pair2: (pair1[0] + pair2[0], pair1[1] + pair2[1])
  214. )
  215. mu = sum_x / n
  216. sigma = np.sqrt(sum_square_x / n - mu ** 2)
  217. size = len(sigma)
  218. remiands_indexes = [i for i, e in enumerate(sigma) if e > 0]
  219. if len(remiands_indexes) < size:
  220. LOGGER.warning(
  221. f"zero standard deviation detected, sigma={sigma}, zeroindexes={np.argwhere(sigma)}"
  222. )
  223. return (
  224. n,
  225. data.mapValues(
  226. lambda x: (x[remiands_indexes] - mu[remiands_indexes])
  227. / sigma[remiands_indexes]
  228. ),
  229. remiands_indexes,
  230. size,
  231. )
  232. return n, data.mapValues(lambda x: (x - mu) / sigma), remiands_indexes, size
  233. def select_columns(data_instance, hit_column_indexes, hit_column_names):
  234. """
  235. select features
  236. """
  237. column_names = data_instance.schema["header"]
  238. num_columns = len(column_names)
  239. # accept all features
  240. if hit_column_indexes == -1:
  241. if len(hit_column_names) > 0:
  242. raise ValueError(f"specify column name when column_indexes=-1 is ambiguity")
  243. return column_names, data_instance.mapValues(lambda inst: inst.features)
  244. # check hit column indexes and column names
  245. name_to_index = {c: i for i, c in enumerate(column_names)}
  246. selected = set()
  247. for name in hit_column_names:
  248. if name not in name_to_index:
  249. raise ValueError(f"feature name `{name}` not found in data schema")
  250. else:
  251. selected.add(name_to_index[name])
  252. for idx in hit_column_indexes:
  253. if 0 <= idx < num_columns:
  254. selected.add(idx)
  255. else:
  256. raise ValueError(f"feature idx={idx} out of bound")
  257. selected = sorted(list(selected))
  258. # take shortcut if all feature hit
  259. if len(selected) == len(column_names):
  260. return column_names, data_instance.mapValues(lambda inst: inst.features)
  261. return (
  262. [column_names[i] for i in selected],
  263. data_instance.mapValues(lambda inst: inst.features[selected]),
  264. )
  265. def vif_from_pearson_matrix(pearson_matrix, threshold=1e-8):
  266. LOGGER.info(f"local vif calc: start")
  267. assert not np.isnan(
  268. pearson_matrix
  269. ).any(), f"should not contains nan: {pearson_matrix}"
  270. N = pearson_matrix.shape[0]
  271. vif = []
  272. LOGGER.info(f"local vif calc: calc matrix eigvals")
  273. eig = sorted([abs(v) for v in np.linalg.eigvalsh(pearson_matrix)])
  274. num_drop = len(list(filter(lambda x: x < threshold, eig)))
  275. det_non_zero = np.prod(eig[num_drop:])
  276. LOGGER.info(f"local vif calc: calc submatrix eigvals")
  277. for i in range(N):
  278. indexes = [j for j in range(N) if j != i]
  279. cofactor_matrix = pearson_matrix[indexes][:, indexes]
  280. cofactor_eig = sorted([abs(v) for v in np.linalg.eigvalsh(cofactor_matrix)])
  281. vif.append(np.prod(cofactor_eig[num_drop:]) / det_non_zero)
  282. LOGGER.info(f"local vif calc: submatrix {i+1}/{N} eig is {vif[-1]}")
  283. LOGGER.info(f"local vif calc done")
  284. return vif
  285. def fix_local_corr(remaind_corr, remainds_indexes, size):
  286. corr = np.zeros((size, size))
  287. corr.fill(np.nan)
  288. corr[np.ix_(remainds_indexes, remainds_indexes)] = np.clip(remaind_corr, -1.0, 1.0)
  289. return corr
  290. def fix_vif(remains_vif, remainds_indexes, size):
  291. vif = np.zeros(size)
  292. vif.fill(np.nan)
  293. vif[remainds_indexes] = remains_vif
  294. return vif
  295. def fix_corr(remaind_corr, m1, m2, remainds_indexes1, remainds_indexes2):
  296. corr = np.zeros((m1, m2))
  297. corr.fill(np.nan)
  298. corr[np.ix_(remainds_indexes1, remainds_indexes2)] = np.clip(
  299. remaind_corr, -1.0, 1.0
  300. )
  301. return corr