isometric_model.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. import numpy as np
  19. from federatedml.util import LOGGER
  20. class SingleMetricInfo(object):
  21. """
  22. Use to Store Metric values
  23. Parameters
  24. ----------
  25. values: ndarray or list
  26. List of metric value of each column. Do not accept missing value.
  27. col_names: list
  28. List of column_names of above list whose length should match with above values.
  29. host_party_ids: list of int (party_id, such as 9999)
  30. If it is a federated metric, list of host party ids
  31. host_values: list of ndarray
  32. The outer list specify each host's values. The inner list are the values of
  33. this party
  34. host_col_names: list of list
  35. Similar to host_values where the content is col_names
  36. """
  37. def __init__(self, values, col_names, host_party_ids=None,
  38. host_values=None, host_col_names=None):
  39. if host_party_ids is None:
  40. host_party_ids = []
  41. if host_values is None:
  42. host_values = []
  43. if host_col_names is None:
  44. host_col_names = []
  45. self.values = values
  46. self.col_names = col_names
  47. self.host_party_ids = host_party_ids
  48. self.host_values = host_values
  49. self.host_col_names = host_col_names
  50. self.check()
  51. def check(self):
  52. if len(self.values) != len(self.col_names):
  53. raise ValueError("When creating SingleMetricValue, length of values "
  54. "and length of col_names should be equal")
  55. if not (len(self.host_party_ids) == len(self.host_values) == len(self.host_col_names)):
  56. raise ValueError("When creating SingleMetricValue, length of values "
  57. "and length of col_names and host_party_ids should be equal")
  58. def union_result(self):
  59. values = list(self.values)
  60. col_names = [("guest", x) for x in self.col_names]
  61. for idx, host_id in enumerate(self.host_party_ids):
  62. values.extend(self.host_values[idx])
  63. col_names.extend([(host_id, x) for x in self.host_col_names[idx]])
  64. if len(values) != len(col_names):
  65. raise AssertionError("union values and col_names should have same length")
  66. values = np.array(values)
  67. return values, col_names
  68. def get_values(self):
  69. return copy.deepcopy(self.values)
  70. def get_col_names(self):
  71. return copy.deepcopy(self.col_names)
  72. def get_partial_values(self, select_col_names, party_id=None):
  73. """
  74. Return values selected by provided col_names.
  75. Use party_id to indicate which party to get. If None, obtain from values,
  76. otherwise, obtain from host_values
  77. """
  78. if party_id is None:
  79. col_name_map = {name: idx for idx, name in enumerate(self.col_names)}
  80. col_indices = [col_name_map[x] for x in select_col_names]
  81. values = np.array(self.values)[col_indices]
  82. else:
  83. if party_id not in self.host_party_ids:
  84. raise ValueError(f"party_id: {party_id} is not in host_party_ids:"
  85. f" {self.host_party_ids}")
  86. party_idx = self.host_party_ids.index(party_id)
  87. col_name_map = {name: idx for idx, name in
  88. enumerate(self.host_col_names[party_idx])}
  89. # LOGGER.debug(f"col_name_map: {col_name_map}")
  90. values = []
  91. host_values = np.array(self.host_values[party_idx])
  92. for host_col_name in select_col_names:
  93. if host_col_name in col_name_map:
  94. values.append(host_values[col_name_map[host_col_name]])
  95. else:
  96. values.append(0)
  97. # col_indices = [col_name_map[x] for x in select_col_names]
  98. # values = np.array(self.host_values[party_idx])[col_indices]
  99. return list(values)
  100. class IsometricModel(object):
  101. """
  102. Use to Store Metric values
  103. Parameters
  104. ----------
  105. metric_name: list of str
  106. The metric name, eg. iv. If a single string
  107. metric_info: list of SingleMetricInfo
  108. """
  109. def __init__(self, metric_name=None, metric_info=None):
  110. if metric_name is None:
  111. metric_name = []
  112. if not isinstance(metric_name, list):
  113. metric_name = [metric_name]
  114. if metric_info is None:
  115. metric_info = []
  116. if not isinstance(metric_info, list):
  117. metric_info = [metric_info]
  118. self._metric_names = metric_name
  119. self._metric_info = metric_info
  120. def add_metric_value(self, metric_name, metric_info):
  121. self._metric_names.append(metric_name)
  122. self._metric_info.append(metric_info)
  123. @property
  124. def valid_value_name(self):
  125. return self._metric_names
  126. def get_metric_info(self, metric_name):
  127. LOGGER.debug(f"valid_value_name: {self.valid_value_name}, "
  128. f"metric_name: {metric_name}")
  129. if metric_name not in self.valid_value_name:
  130. return None
  131. return self._metric_info[self._metric_names.index(metric_name)]
  132. def get_all_metric_info(self):
  133. return self._metric_info