step.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import copy
  17. import numpy as np
  18. from federatedml.statistic.data_overview import get_header, get_anonymous_header
  19. from federatedml.util import consts
  20. from federatedml.util import LOGGER
  21. from federatedml.util.data_transform import set_schema
  22. class Step(object):
  23. def __init__(self):
  24. self.feature_list = []
  25. self.step_direction = ""
  26. self.n_step = 0
  27. self.n_model = 0
  28. def set_step_info(self, step_info):
  29. n_step, n_model = step_info
  30. self.n_step = n_step
  31. self.n_model = n_model
  32. def get_flowid(self):
  33. flowid = "train.step{}.model{}".format(self.n_step, self.n_model)
  34. return flowid
  35. @staticmethod
  36. def slice_data_instance(data_instance, feature_mask):
  37. """
  38. return data_instance with features at given indices
  39. Parameters
  40. ----------
  41. data_instance: data Instance object, input data
  42. feature_mask: mask to filter data_instance
  43. """
  44. data_instance.features = data_instance.features[feature_mask]
  45. return data_instance
  46. @staticmethod
  47. def get_new_schema(original_data, feature_mask):
  48. schema = copy.deepcopy(original_data.schema)
  49. old_header = get_header(original_data)
  50. new_header = [old_header[i] for i in np.where(feature_mask > 0)[0]]
  51. schema["header"] = new_header
  52. old_anonymous_header = get_anonymous_header(original_data)
  53. if old_anonymous_header:
  54. new_anonymous_header = [old_anonymous_header[i] for i in np.where(feature_mask > 0)[0]]
  55. schema["anonymous_header"] = new_anonymous_header
  56. LOGGER.debug(f"given feature_mask: {feature_mask}, new anonymous header is: {new_anonymous_header}")
  57. return schema
  58. def run(self, original_model, train_data, validate_data, feature_mask):
  59. model = copy.deepcopy(original_model)
  60. current_flowid = self.get_flowid()
  61. model.set_flowid(current_flowid)
  62. if original_model.role != consts.ARBITER:
  63. curr_train_data = train_data.mapValues(lambda v: Step.slice_data_instance(v, feature_mask))
  64. new_schema = Step.get_new_schema(train_data, feature_mask)
  65. # LOGGER.debug("new schema is: {}".format(new_schema))
  66. set_schema(curr_train_data, new_schema)
  67. model.header = new_schema.get("header")
  68. else:
  69. curr_train_data = train_data
  70. model.fit(curr_train_data)
  71. return model