__init__.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import inspect
  18. from functools import wraps
  19. from pathlib import Path
  20. from filelock import FileLock as _FileLock
  21. from fate_arch.protobuf.python.default_empty_fill_pb2 import DefaultEmptyFillMessage
  22. from fate_flow.component_env_utils import provider_utils
  23. from fate_flow.db.runtime_config import RuntimeConfig
  24. from fate_flow.settings import stat_logger
  25. def serialize_buffer_object(buffer_object):
  26. # the type is bytes, not str
  27. serialized_string = buffer_object.SerializeToString()
  28. if not serialized_string:
  29. fill_message = DefaultEmptyFillMessage()
  30. fill_message.flag = 'set'
  31. serialized_string = fill_message.SerializeToString()
  32. return serialized_string
  33. def get_proto_buffer_class(buffer_name):
  34. module_path, base_import_path = provider_utils.get_provider_model_paths(RuntimeConfig.COMPONENT_PROVIDER)
  35. exception = ModuleNotFoundError(f'no module named {buffer_name}')
  36. for f in module_path.glob('*.py'):
  37. try:
  38. proto_module = importlib.import_module('.'.join([*base_import_path, f.stem]))
  39. for name, obj in inspect.getmembers(proto_module):
  40. if inspect.isclass(obj) and name == buffer_name:
  41. return obj
  42. except Exception as e:
  43. exception = e
  44. stat_logger.warning(e)
  45. raise exception
  46. def parse_proto_object(buffer_name, serialized_string, buffer_class=None):
  47. try:
  48. if buffer_class is None:
  49. buffer_class = get_proto_buffer_class(buffer_name)
  50. buffer_object = buffer_class()
  51. except Exception as e:
  52. stat_logger.exception('Can not restore proto buffer object')
  53. raise e
  54. buffer_name = type(buffer_object).__name__
  55. try:
  56. buffer_object.ParseFromString(serialized_string)
  57. except Exception as e1:
  58. stat_logger.exception(e1)
  59. try:
  60. DefaultEmptyFillMessage().ParseFromString(serialized_string)
  61. buffer_object.ParseFromString(bytes())
  62. except Exception as e2:
  63. stat_logger.exception(e2)
  64. raise e1
  65. else:
  66. stat_logger.info(f'parse {buffer_name} proto object with default values')
  67. else:
  68. stat_logger.info(f'parse {buffer_name} proto object normal')
  69. return buffer_object
  70. def lock(method):
  71. @wraps(method)
  72. def magic(self, *args, **kwargs):
  73. with self.lock:
  74. return method(self, *args, **kwargs)
  75. return magic
  76. def local_cache_required(locking=False):
  77. def decorator(method):
  78. @wraps(method)
  79. def magic(self, *args, **kwargs):
  80. if not self.exists():
  81. raise FileNotFoundError(f'Can not found {self.model_id} {self.model_version} model local cache')
  82. if not locking:
  83. return method(self, *args, **kwargs)
  84. with self.lock:
  85. return method(self, *args, **kwargs)
  86. return magic
  87. return decorator
  88. class Locker:
  89. def __init__(self, directory):
  90. if isinstance(directory, str):
  91. directory = Path(directory)
  92. self.directory = directory
  93. self.lock = self._lock
  94. @property
  95. def _lock(self):
  96. return FileLock(self.directory / '.lock')
  97. def __copy__(self):
  98. return self
  99. def __deepcopy__(self, memo):
  100. return self
  101. # https://docs.python.org/3/library/pickle.html#handling-stateful-objects
  102. def __getstate__(self):
  103. state = self.__dict__.copy()
  104. state.pop('lock')
  105. return state
  106. def __setstate__(self, state):
  107. self.__dict__.update(state)
  108. self.lock = self._lock
  109. class FileLock(_FileLock):
  110. def _acquire(self, *args, **kwargs):
  111. Path(self._lock_file).parent.mkdir(parents=True, exist_ok=True)
  112. super()._acquire(*args, **kwargs)