base_model.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import operator
import typing
from enum import IntEnum

from peewee import Field, IntegerField, FloatField, BigIntegerField, TextField, Model, CompositeKey, Metadata

from fate_arch.common import conf_utils, EngineType
from fate_arch.common.base_utils import (
    current_timestamp,
    serialize_b64,
    deserialize_b64,
    timestamp_to_date,
    date_string_to_timestamp,
    json_dumps,
    json_loads,
)
from fate_arch.federation import FederationEngine
  23. is_standalone = conf_utils.get_base_config("default_engines", {}).get(
  24. EngineType.FEDERATION).upper() == FederationEngine.STANDALONE
  25. if is_standalone:
  26. from playhouse.apsw_ext import DateTimeField
  27. else:
  28. from peewee import DateTimeField
  29. CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
  30. AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
  31. class SerializedType(IntEnum):
  32. PICKLE = 1
  33. JSON = 2
  34. class LongTextField(TextField):
  35. field_type = 'LONGTEXT'
  36. class JSONField(LongTextField):
  37. default_value = {}
  38. def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
  39. self._object_hook = object_hook
  40. self._object_pairs_hook = object_pairs_hook
  41. super().__init__(**kwargs)
  42. def db_value(self, value):
  43. if value is None:
  44. value = self.default_value
  45. return json_dumps(value)
  46. def python_value(self, value):
  47. if not value:
  48. return self.default_value
  49. return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  50. class ListField(JSONField):
  51. default_value = []
  52. class SerializedField(LongTextField):
  53. def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
  54. self._serialized_type = serialized_type
  55. self._object_hook = object_hook
  56. self._object_pairs_hook = object_pairs_hook
  57. super().__init__(**kwargs)
  58. def db_value(self, value):
  59. if self._serialized_type == SerializedType.PICKLE:
  60. return serialize_b64(value, to_str=True)
  61. elif self._serialized_type == SerializedType.JSON:
  62. if value is None:
  63. return None
  64. return json_dumps(value, with_type=True)
  65. else:
  66. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  67. def python_value(self, value):
  68. if self._serialized_type == SerializedType.PICKLE:
  69. return deserialize_b64(value)
  70. elif self._serialized_type == SerializedType.JSON:
  71. if value is None:
  72. return {}
  73. return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  74. else:
  75. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  76. def is_continuous_field(cls: typing.Type) -> bool:
  77. if cls in CONTINUOUS_FIELD_TYPE:
  78. return True
  79. for p in cls.__bases__:
  80. if p in CONTINUOUS_FIELD_TYPE:
  81. return True
  82. elif p != Field and p != object:
  83. if is_continuous_field(p):
  84. return True
  85. else:
  86. return False
  87. def auto_date_timestamp_field():
  88. return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  89. def auto_date_timestamp_db_field():
  90. return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  91. def remove_field_name_prefix(field_name):
  92. return field_name[2:] if field_name.startswith('f_') else field_name
  93. class BaseModel(Model):
  94. f_create_time = BigIntegerField(null=True)
  95. f_create_date = DateTimeField(null=True)
  96. f_update_time = BigIntegerField(null=True)
  97. f_update_date = DateTimeField(null=True)
  98. def to_json(self):
  99. # This function is obsolete
  100. return self.to_dict()
  101. def to_dict(self):
  102. return self.__dict__['__data__']
  103. def to_human_model_dict(self, only_primary_with: list = None):
  104. model_dict = self.__dict__['__data__']
  105. if not only_primary_with:
  106. return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
  107. human_model_dict = {}
  108. for k in self._meta.primary_key.field_names:
  109. human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
  110. for k in only_primary_with:
  111. human_model_dict[k] = model_dict[f'f_{k}']
  112. return human_model_dict
  113. @property
  114. def meta(self) -> Metadata:
  115. return self._meta
  116. @classmethod
  117. def get_primary_keys_name(cls):
  118. return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
  119. cls._meta.primary_key.name]
  120. @classmethod
  121. def getter_by(cls, attr):
  122. return operator.attrgetter(attr)(cls)
  123. @classmethod
  124. def query(cls, reverse=None, order_by=None, **kwargs):
  125. filters = []
  126. for f_n, f_v in kwargs.items():
  127. attr_name = 'f_%s' % f_n
  128. if not hasattr(cls, attr_name) or f_v is None:
  129. continue
  130. if type(f_v) in {list, set}:
  131. f_v = list(f_v)
  132. if is_continuous_field(type(getattr(cls, attr_name))):
  133. if len(f_v) == 2:
  134. for i, v in enumerate(f_v):
  135. if isinstance(v, str) and f_n in auto_date_timestamp_field():
  136. # time type: %Y-%m-%d %H:%M:%S
  137. f_v[i] = date_string_to_timestamp(v)
  138. lt_value = f_v[0]
  139. gt_value = f_v[1]
  140. if lt_value is not None and gt_value is not None:
  141. filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
  142. elif lt_value is not None:
  143. filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
  144. elif gt_value is not None:
  145. filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
  146. else:
  147. filters.append(operator.attrgetter(attr_name)(cls) << f_v)
  148. else:
  149. filters.append(operator.attrgetter(attr_name)(cls) == f_v)
  150. if filters:
  151. query_records = cls.select().where(*filters)
  152. if reverse is not None:
  153. if not order_by or not hasattr(cls, f"f_{order_by}"):
  154. order_by = "create_time"
  155. if reverse is True:
  156. query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").desc())
  157. elif reverse is False:
  158. query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").asc())
  159. return [query_record for query_record in query_records]
  160. else:
  161. return []
  162. @classmethod
  163. def insert(cls, __data=None, **insert):
  164. if isinstance(__data, dict) and __data:
  165. __data[cls._meta.combined["f_create_time"]] = current_timestamp()
  166. if insert:
  167. insert["f_create_time"] = current_timestamp()
  168. return super().insert(__data, **insert)
  169. # update and insert will call this method
  170. @classmethod
  171. def _normalize_data(cls, data, kwargs):
  172. normalized = super()._normalize_data(data, kwargs)
  173. if not normalized:
  174. return {}
  175. normalized[cls._meta.combined["f_update_time"]] = current_timestamp()
  176. for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
  177. if {f"f_{f_n}_time", f"f_{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
  178. cls._meta.combined[f"f_{f_n}_time"] in normalized and \
  179. normalized[cls._meta.combined[f"f_{f_n}_time"]] is not None:
  180. normalized[cls._meta.combined[f"f_{f_n}_date"]] = timestamp_to_date(
  181. normalized[cls._meta.combined[f"f_{f_n}_time"]])
  182. return normalized