classify_label_checker.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. ################################################################################
  19. #
  20. #
  21. ################################################################################
  22. # =============================================================================
  23. # Label Checker
  24. # =============================================================================
  25. from federatedml.util import consts
  class ClassifyLabelChecker(object):
      """Validate the labels of a classification task and collect its distinct classes."""

      @staticmethod
      def validate_label(data_inst):
          """
          Label checker for classification tasks.

          Collects every distinct label in ``data_inst`` and checks that their
          number does not exceed ``consts.MAX_CLASSNUM``.

          Parameters
          ----------
          data_inst : Table
              Table whose values are data instances as defined in
              federatedml/feature/instance.py (each exposing ``label``).

          Returns
          -------
          num_class : int
              The number of distinct labels.
          labels : list
              The distinct labels, in arbitrary (set iteration) order.

          Raises
          ------
          ValueError
              If more than ``consts.MAX_CLASSNUM`` distinct labels are found.
          """
          # Each partition produces its own label set; union them into the global set.
          class_set = data_inst.applyPartitions(ClassifyLabelChecker.get_all_class).reduce(
              lambda x, y: x | y)
          # NOTE(review): assumes reduce() over a table without rows may return
          # None; treat that as "no labels" rather than failing on len(None).
          if class_set is None:
              class_set = set()

          num_class = len(class_set)
          # The per-partition check cannot see the union across partitions,
          # so the global count has to be checked again here.
          if num_class > consts.MAX_CLASSNUM:
              raise ValueError("In Classify Task, max dif classes should no more than %d" % consts.MAX_CLASSNUM)
          return num_class, list(class_set)

      @staticmethod
      def get_all_class(kv_iterator):
          """
          Collect the distinct labels of a single partition.

          Parameters
          ----------
          kv_iterator : iterable of (key, instance) pairs
              Rows of one partition; each instance exposes ``label``.

          Returns
          -------
          set
              The distinct labels seen in this partition.

          Raises
          ------
          ValueError
              As soon as more than ``consts.MAX_CLASSNUM`` distinct labels are seen.
          """
          class_set = set()
          for _, inst in kv_iterator:
              class_set.add(inst.label)
              # Fail fast: no need to scan the rest of the partition once over the limit.
              if len(class_set) > consts.MAX_CLASSNUM:
                  raise ValueError("In Classify Task, max dif classes should no more than %d" % consts.MAX_CLASSNUM)
          return class_set
  class RegressionLabelChecker(object):
      """Validate that every label of a regression task is numeric."""

      @staticmethod
      def validate_label(data_inst):
          """
          Label checker for regression tasks.

          Checks that every label can be converted to ``float``.

          Parameters
          ----------
          data_inst : Table
              Table whose values are data instances as defined in
              federatedml/feature/instance.py (each exposing ``label``).

          Raises
          ------
          ValueError
              If any label cannot be converted to ``float``.
          """
          # The mapped result is discarded: the call exists only to run the
          # per-value check in test_numeric_data, which raises on a bad label.
          # NOTE(review): relies on mapValues being evaluated eagerly by the
          # computing backend; if it were lazy the check would never run - confirm.
          data_inst.mapValues(RegressionLabelChecker.test_numeric_data)

      @staticmethod
      def test_numeric_data(value):
          """
          Raise ``ValueError`` unless ``value.label`` converts to ``float``.

          Only ordinary exceptions (bad string, wrong type, missing attribute,
          ...) are translated into ``ValueError``; ``KeyboardInterrupt`` and
          ``SystemExit`` propagate unchanged instead of being masked.
          """
          try:
              float(value.label)
          except Exception:
              raise ValueError("In Regression Task, all label should be numeric!!")