# version_controller.py (4.9 KB)
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# algorithm version compatibility control
import logging

from fate_arch.common import file_utils
from fate_flow.settings import INCOMPATIBLE_VERSION_CONF


  19. class VersionController:
  20. INCOMPATIBLE_VERSION = {}
  21. @classmethod
  22. def init(cls):
  23. try:
  24. conf = file_utils.load_yaml_conf(INCOMPATIBLE_VERSION_CONF)
  25. for key, key_version in conf.items():
  26. cls.INCOMPATIBLE_VERSION[key] = {}
  27. for version in conf[key]:
  28. cls.INCOMPATIBLE_VERSION[key][str(version)] = conf[key][version]
  29. except Exception as e:
  30. pass
  31. @classmethod
  32. def job_provider_version_check(cls, providers_info, local_role, local_party_id):
  33. incompatible_info = {}
  34. incompatible = False
  35. if local_role in providers_info:
  36. local_provider = providers_info[local_role].get(int(local_party_id), {}) \
  37. or providers_info[local_role].get(str(local_party_id), {})
  38. for role, role_provider in providers_info.items():
  39. incompatible_info[role] = {}
  40. for party_id, provider in role_provider.items():
  41. if role == local_role and str(party_id) == str(local_party_id):
  42. continue
  43. role_incompatible_info = cls.provider_version_check(local_provider, party_provider=provider)
  44. if role_incompatible_info:
  45. incompatible = True
  46. incompatible_info[role][party_id] = role_incompatible_info
  47. if incompatible:
  48. raise ValueError(f"version compatibility check failed: {incompatible_info}")
  49. @classmethod
  50. def provider_version_check(cls, local_provider, party_provider):
  51. incompatible_info = {}
  52. for component, info in local_provider.items():
  53. if party_provider.get(component):
  54. local_version = local_provider.get(component).get("provider").get("version")
  55. party_version = party_provider.get(component).get("provider").get("version")
  56. if cls.is_incompatible(local_version, party_version):
  57. if component in incompatible_info:
  58. incompatible_info[component].append((local_version, party_version))
  59. else:
  60. incompatible_info[component] = [(local_version, party_version)]
  61. return incompatible_info
  62. @classmethod
  63. def is_incompatible(cls, source_version, dest_version, key="FATE"):
  64. if not source_version or not dest_version:
  65. return False
  66. index = len(source_version)
  67. while True:
  68. if source_version[:index] in cls.INCOMPATIBLE_VERSION.get(key, {}).keys():
  69. for incompatible_value in cls.INCOMPATIBLE_VERSION.get(key)[source_version[:index]].split(","):
  70. if cls.is_match(dest_version, incompatible_value.strip()):
  71. return True
  72. index -= 1
  73. if index == 0:
  74. return False
  75. @classmethod
  76. def is_match(cls, dest_ver, incompatible_value):
  77. symbols, incompatible_ver = cls.extract_symbols(incompatible_value)
  78. dest_ver_list = cls.extend_version([int(_) for _ in dest_ver.split(".")])
  79. incompatible_ver_list = cls.extend_version([int(_) for _ in incompatible_ver.split(".")])
  80. print(dest_ver_list, incompatible_ver_list, symbols)
  81. for index in range(4):
  82. if dest_ver_list[index] == incompatible_ver_list[index]:
  83. continue
  84. if dest_ver_list[index] > incompatible_ver_list[index]:
  85. return True if ">" in symbols else False
  86. if dest_ver_list[index] < incompatible_ver_list[index]:
  87. return True if "<" in symbols else False
  88. return True if "=" in symbols else False
  89. @classmethod
  90. def extend_version(cls, v):
  91. v_len = len(v)
  92. if v_len < 4:
  93. for i in range(4 - v_len):
  94. v.append(0)
  95. return v
  96. @classmethod
  97. def extract_symbols(cls, incompatible_value):
  98. symbols_list = ["<", ">", "="]
  99. index = 0
  100. for index, ver in enumerate(incompatible_value):
  101. if ver not in symbols_list:
  102. break
  103. symbol = incompatible_value[0: index]
  104. if not incompatible_value[0: index]:
  105. symbol = "="
  106. return symbol, incompatible_value[index:]