utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import os
  2. import json
  3. import socket
  4. import time
  5. import typing
  6. import tarfile
  7. import datetime
  8. from enum import Enum, IntEnum
  9. PROJECT_BASE = os.getenv("FATE_DEPLOY_BASE")
  10. def start_cluster_standalone_job_server():
  11. print("use service.sh to start standalone node server....")
  12. os.system("sh service.sh start --standalone_node")
  13. time.sleep(5)
  14. def get_parser_version_set():
  15. return {"1", "2"}
  16. def get_project_base_directory():
  17. global PROJECT_BASE
  18. if PROJECT_BASE is None:
  19. PROJECT_BASE = os.path.abspath(
  20. os.path.join(
  21. os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir
  22. )
  23. )
  24. return PROJECT_BASE
  25. def download_from_request(http_response, tar_file_name, extract_dir):
  26. with open(tar_file_name, "wb") as fw:
  27. for chunk in http_response.iter_content(1024):
  28. if chunk:
  29. fw.write(chunk)
  30. tar = tarfile.open(tar_file_name, "r:gz")
  31. file_names = tar.getnames()
  32. for file_name in file_names:
  33. tar.extract(file_name, extract_dir)
  34. tar.close()
  35. os.remove(tar_file_name)
  36. def check_config(config: typing.Dict, required_arguments: typing.List):
  37. no_arguments = []
  38. error_arguments = []
  39. for require_argument in required_arguments:
  40. if isinstance(require_argument, tuple):
  41. config_value = config.get(require_argument[0], None)
  42. if isinstance(require_argument[1], (tuple, list)):
  43. if config_value not in require_argument[1]:
  44. error_arguments.append(require_argument)
  45. elif config_value != require_argument[1]:
  46. error_arguments.append(require_argument)
  47. elif require_argument not in config:
  48. no_arguments.append(require_argument)
  49. if no_arguments or error_arguments:
  50. raise Exception(
  51. "the following arguments are required: {} {}".format(
  52. ",".join(no_arguments),
  53. ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]),
  54. )
  55. )
  56. def preprocess(**kwargs):
  57. kwargs.pop('self', None)
  58. kwargs.pop('kwargs', None)
  59. config_data = kwargs.pop('config_data', {})
  60. dsl_data = kwargs.pop('dsl_data', {})
  61. output_path = kwargs.pop('output_path', None)
  62. if output_path is not None:
  63. config_data['output_path'] = os.path.abspath(output_path)
  64. local = config_data.pop('local', {})
  65. party_id = kwargs.pop('party_id', None)
  66. role = kwargs.pop('role', None)
  67. if party_id is not None:
  68. kwargs['party_id'] = local['party_id'] = int(party_id)
  69. if role is not None:
  70. kwargs['role'] = local['role'] = role
  71. if local:
  72. config_data['local'] = local
  73. for k, v in kwargs.items():
  74. if v is not None:
  75. if k in {'job_id', 'model_version'}:
  76. v = str(v)
  77. elif k in {'party_id', 'step_index'}:
  78. v = int(v)
  79. config_data[k] = v
  80. return config_data, dsl_data
  81. def check_output_path(path):
  82. if not os.path.isabs(path):
  83. return os.path.join(os.path.abspath(os.curdir), path)
  84. return path
  85. def string_to_bytes(string):
  86. return string if isinstance(string, bytes) else string.encode(encoding="utf-8")
  87. def get_lan_ip():
  88. if os.name != "nt":
  89. import fcntl
  90. import struct
  91. def get_interface_ip(ifname):
  92. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  93. return socket.inet_ntoa(
  94. fcntl.ioctl(
  95. s.fileno(),
  96. 0x8915,
  97. struct.pack("256s", string_to_bytes(ifname[:15])),
  98. )[20:24]
  99. )
  100. ip = socket.gethostbyname(socket.getfqdn())
  101. if ip.startswith("127.") and os.name != "nt":
  102. interfaces = [
  103. "bond1",
  104. "eth0",
  105. "eth1",
  106. "eth2",
  107. "wlan0",
  108. "wlan1",
  109. "wifi0",
  110. "ath0",
  111. "ath1",
  112. "ppp0",
  113. ]
  114. for ifname in interfaces:
  115. try:
  116. ip = get_interface_ip(ifname)
  117. break
  118. except IOError as e:
  119. pass
  120. return ip or ""
  121. class CustomJSONEncoder(json.JSONEncoder):
  122. def __init__(self, **kwargs):
  123. super(CustomJSONEncoder, self).__init__(**kwargs)
  124. def default(self, obj):
  125. if isinstance(obj, datetime.datetime):
  126. return obj.strftime("%Y-%m-%d %H:%M:%S")
  127. elif isinstance(obj, datetime.date):
  128. return obj.strftime("%Y-%m-%d")
  129. elif isinstance(obj, datetime.timedelta):
  130. return str(obj)
  131. elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum):
  132. return obj.value
  133. else:
  134. return json.JSONEncoder.default(self, obj)
  135. def json_dumps(src, byte=False, indent=None):
  136. if byte:
  137. return string_to_bytes(json.dumps(src, indent=indent, cls=CustomJSONEncoder))
  138. else:
  139. return json.dumps(src, indent=indent, cls=CustomJSONEncoder)