123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import json
- import typing
- from pathlib import Path
- from ruamel import yaml
def merge_dict(dict1, dict2):
    """
    Recursively merges two dictionaries into a newly created dictionary.

    Keys present in only one input are carried over as-is (the value object
    is reused, not copied). For keys present in both inputs the two values
    must have the same type name; nested dictionaries are merged recursively
    and any other value is taken from ``dict2``.

    The result lists the keys of ``dict1`` first, followed by the keys that
    only appear in ``dict2``, so the output order is reproducible.

    Parameters
    ----------
    dict1: dict, base dictionary
    dict2: dict, dictionary whose values win when a key is in both inputs

    Returns
    -------
    dict, merged dictionary; neither input is modified

    Raises
    ------
    AssertionError, if a shared key maps to values with different type names
    """
    merge_ret = {}

    for key, val1 in dict1.items():
        if key not in dict2:
            merge_ret[key] = val1
            continue

        val2 = dict2[key]
        # Raised explicitly (instead of an ``assert`` statement) so the check
        # still runs under ``python -O``; the exception type is unchanged.
        if type(val1).__name__ != type(val2).__name__:
            raise AssertionError(
                f"type mismatch for key {key!r}: "
                f"{type(val1).__name__} vs {type(val2).__name__}"
            )

        if isinstance(val1, dict):
            merge_ret[key] = merge_dict(val1, val2)
        else:
            merge_ret[key] = val2

    # Append the keys that only exist in dict2, in their original order.
    for key, val2 in dict2.items():
        if key not in dict1:
            merge_ret[key] = val2

    return merge_ret
def extract_explicit_parameter(func):
    """
    Decorator that hands the caller's keyword arguments to ``func`` as one dict.

    The returned wrapper forwards positional arguments unchanged and passes
    every keyword argument it received bundled in a single keyword argument
    named ``explict_parameters`` (spelling kept as-is, it is the key the
    wrapped function receives).

    Parameters
    ----------
    func: callable, function to wrap; it is called as
        ``func(*args, explict_parameters={...})``

    Returns
    -------
    callable, wrapper that carries the name and docstring of ``func``
    """
    # Imported locally so this block does not depend on a file-level import.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        explict_kwargs = {"explict_parameters": kwargs}
        return func(*args, **explict_kwargs)

    return wrapper
def load_job_config(path):
    """
    Builds a JobConfig from the configuration file at ``path``.

    Thin convenience wrapper that delegates to ``JobConfig.load``.
    """
    return JobConfig.load(path)
class Parties:
    """
    Holds the ``host``, ``guest`` and ``arbiter`` entries of a parties mapping.

    A role that is missing from the mapping is stored as ``None``.
    """

    def __init__(self, parties):
        # ``get`` without an explicit default already yields None for a
        # missing role.
        self.host = parties.get("host")
        self.guest = parties.get("guest")
        self.arbiter = parties.get("arbiter")
class JobConfig(object):
    """
    Job-level settings read from a configuration dictionary.

    Attributes are filled from the keys ``parties``, ``backend``,
    ``work_mode``, ``data_base_dir`` and ``system_setting``; a missing key
    falls back to the default used in ``__init__``.
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config: dict, configuration content, e.g. the result of ``load_from_file``
        """
        # ``or {}`` also covers an explicit null "parties" entry, which would
        # otherwise reach Parties as None and fail on ``.get``.
        self.parties = Parties(config.get("parties") or {})
        self.backend = config.get("backend", 0)
        self.work_mode = config.get("work_mode", 0)
        self.data_base_dir = config.get("data_base_dir", "")
        self.system_setting = config.get("system_setting", {})

    @staticmethod
    def load(path: typing.Union[str, Path]) -> "JobConfig":
        """
        Reads the conf file at ``path`` and wraps its content in a JobConfig.

        Parameters
        ----------
        path: str or Path, path to a json or yaml conf file

        Returns
        -------
        JobConfig, configuration object built from the file content
        """
        conf = JobConfig.load_from_file(path)
        return JobConfig(conf)

    @staticmethod
    def load_from_file(path: typing.Union[str, Path]) -> dict:
        """
        Loads conf content from json or yaml file. Used to read in parameter configuration

        Parameters
        ----------
        path: str or Path, path to conf file; ``None`` yields an empty configuration

        Returns
        -------
        dict, parameter configuration in dictionary format

        Raises
        ------
        ValueError, if the file suffix is not ``.yaml``, ``.yml`` or ``.json``
        """
        if isinstance(path, str):
            path = Path(path)

        config = {}
        if path is None:
            return config

        # Compare case-insensitively so e.g. "CONF.JSON" is accepted too.
        file_type = path.suffix.lower()
        with path.open("r") as f:
            if file_type in (".yaml", ".yml"):
                # NOTE(review): recent ruamel.yaml releases reportedly dropped
                # the module-level ``safe_load`` helper - confirm the pinned
                # version still provides it.
                loaded = yaml.safe_load(f)
            elif file_type == ".json":
                loaded = json.load(f)
            else:
                raise ValueError(f"Cannot load conf from file type {path.suffix}")

        # A document without content (e.g. JSON ``null``) parses to None;
        # treat it as "no settings" instead of failing in ``dict.update``.
        if loaded is not None:
            config.update(loaded)
        return config
|