tools.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import json
  17. import typing
  18. from pathlib import Path
  19. from ruamel import yaml
  20. def merge_dict(dict1, dict2):
  21. merge_ret = {}
  22. keyset = dict1.keys() | dict2.keys()
  23. for key in keyset:
  24. if key in dict1 and key in dict2:
  25. val1 = dict1.get(key)
  26. val2 = dict2.get(key)
  27. assert type(val1).__name__ == type(val2).__name__
  28. if isinstance(val1, dict):
  29. merge_ret[key] = merge_dict(val1, val2)
  30. else:
  31. merge_ret[key] = val2
  32. elif key in dict1:
  33. merge_ret[key] = dict1.get(key)
  34. else:
  35. merge_ret[key] = dict2.get(key)
  36. return merge_ret
  37. def extract_explicit_parameter(func):
  38. def wrapper(*args, **kwargs):
  39. explict_kwargs = {"explict_parameters": kwargs}
  40. return func(*args, **explict_kwargs)
  41. return wrapper
  42. def load_job_config(path):
  43. config = JobConfig.load(path)
  44. return config
  45. class Parties(object):
  46. def __init__(self, parties):
  47. self.host = parties.get("host", None)
  48. self.guest = parties.get("guest", None)
  49. self.arbiter = parties.get("arbiter", None)
  50. class JobConfig(object):
  51. def __init__(self, config):
  52. self.parties = Parties(config.get("parties", {}))
  53. self.backend = config.get("backend", 0)
  54. self.work_mode = config.get("work_mode", 0)
  55. self.data_base_dir = config.get("data_base_dir", "")
  56. self.system_setting = config.get("system_setting", {})
  57. @staticmethod
  58. def load(path: typing.Union[str, Path]):
  59. conf = JobConfig.load_from_file(path)
  60. return JobConfig(conf)
  61. @staticmethod
  62. def load_from_file(path: typing.Union[str, Path]):
  63. """
  64. Loads conf content from json or yaml file. Used to read in parameter configuration
  65. Parameters
  66. ----------
  67. path: str, path to conf file, should be absolute path
  68. Returns
  69. -------
  70. dict, parameter configuration in dictionary format
  71. """
  72. if isinstance(path, str):
  73. path = Path(path)
  74. config = {}
  75. if path is not None:
  76. file_type = path.suffix
  77. with path.open("r") as f:
  78. if file_type == ".yaml":
  79. config.update(yaml.safe_load(f))
  80. elif file_type == ".json":
  81. config.update(json.load(f))
  82. else:
  83. raise ValueError(f"Cannot load conf from file type {file_type}")
  84. return config