123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import operator
- import typing
- from enum import IntEnum
- from peewee import Field, IntegerField, FloatField, BigIntegerField, TextField, Model, CompositeKey, Metadata
- from fate_arch.common import conf_utils, EngineType
- from fate_arch.common.base_utils import current_timestamp, serialize_b64, deserialize_b64, timestamp_to_date, date_string_to_timestamp, json_dumps, json_loads
- from fate_arch.federation import FederationEngine
- is_standalone = conf_utils.get_base_config("default_engines", {}).get(
- EngineType.FEDERATION).upper() == FederationEngine.STANDALONE
- if is_standalone:
- from playhouse.apsw_ext import DateTimeField
- else:
- from peewee import DateTimeField
- CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
- AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
- class SerializedType(IntEnum):
- PICKLE = 1
- JSON = 2
- class LongTextField(TextField):
- field_type = 'LONGTEXT'
- class JSONField(LongTextField):
- default_value = {}
- def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
- self._object_hook = object_hook
- self._object_pairs_hook = object_pairs_hook
- super().__init__(**kwargs)
- def db_value(self, value):
- if value is None:
- value = self.default_value
- return json_dumps(value)
- def python_value(self, value):
- if not value:
- return self.default_value
- return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
- class ListField(JSONField):
- default_value = []
- class SerializedField(LongTextField):
- def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
- self._serialized_type = serialized_type
- self._object_hook = object_hook
- self._object_pairs_hook = object_pairs_hook
- super().__init__(**kwargs)
- def db_value(self, value):
- if self._serialized_type == SerializedType.PICKLE:
- return serialize_b64(value, to_str=True)
- elif self._serialized_type == SerializedType.JSON:
- if value is None:
- return None
- return json_dumps(value, with_type=True)
- else:
- raise ValueError(f"the serialized type {self._serialized_type} is not supported")
- def python_value(self, value):
- if self._serialized_type == SerializedType.PICKLE:
- return deserialize_b64(value)
- elif self._serialized_type == SerializedType.JSON:
- if value is None:
- return {}
- return json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
- else:
- raise ValueError(f"the serialized type {self._serialized_type} is not supported")
- def is_continuous_field(cls: typing.Type) -> bool:
- if cls in CONTINUOUS_FIELD_TYPE:
- return True
- for p in cls.__bases__:
- if p in CONTINUOUS_FIELD_TYPE:
- return True
- elif p != Field and p != object:
- if is_continuous_field(p):
- return True
- else:
- return False
- def auto_date_timestamp_field():
- return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
- def auto_date_timestamp_db_field():
- return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
- def remove_field_name_prefix(field_name):
- return field_name[2:] if field_name.startswith('f_') else field_name
- class BaseModel(Model):
- f_create_time = BigIntegerField(null=True)
- f_create_date = DateTimeField(null=True)
- f_update_time = BigIntegerField(null=True)
- f_update_date = DateTimeField(null=True)
- def to_json(self):
- # This function is obsolete
- return self.to_dict()
- def to_dict(self):
- return self.__dict__['__data__']
- def to_human_model_dict(self, only_primary_with: list = None):
- model_dict = self.__dict__['__data__']
- if not only_primary_with:
- return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
- human_model_dict = {}
- for k in self._meta.primary_key.field_names:
- human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
- for k in only_primary_with:
- human_model_dict[k] = model_dict[f'f_{k}']
- return human_model_dict
- @property
- def meta(self) -> Metadata:
- return self._meta
- @classmethod
- def get_primary_keys_name(cls):
- return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
- cls._meta.primary_key.name]
- @classmethod
- def getter_by(cls, attr):
- return operator.attrgetter(attr)(cls)
- @classmethod
- def query(cls, reverse=None, order_by=None, **kwargs):
- filters = []
- for f_n, f_v in kwargs.items():
- attr_name = 'f_%s' % f_n
- if not hasattr(cls, attr_name) or f_v is None:
- continue
- if type(f_v) in {list, set}:
- f_v = list(f_v)
- if is_continuous_field(type(getattr(cls, attr_name))):
- if len(f_v) == 2:
- for i, v in enumerate(f_v):
- if isinstance(v, str) and f_n in auto_date_timestamp_field():
- # time type: %Y-%m-%d %H:%M:%S
- f_v[i] = date_string_to_timestamp(v)
- lt_value = f_v[0]
- gt_value = f_v[1]
- if lt_value is not None and gt_value is not None:
- filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
- elif lt_value is not None:
- filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
- elif gt_value is not None:
- filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
- else:
- filters.append(operator.attrgetter(attr_name)(cls) << f_v)
- else:
- filters.append(operator.attrgetter(attr_name)(cls) == f_v)
- if filters:
- query_records = cls.select().where(*filters)
- if reverse is not None:
- if not order_by or not hasattr(cls, f"f_{order_by}"):
- order_by = "create_time"
- if reverse is True:
- query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").desc())
- elif reverse is False:
- query_records = query_records.order_by(cls.getter_by(f"f_{order_by}").asc())
- return [query_record for query_record in query_records]
- else:
- return []
- @classmethod
- def insert(cls, __data=None, **insert):
- if isinstance(__data, dict) and __data:
- __data[cls._meta.combined["f_create_time"]] = current_timestamp()
- if insert:
- insert["f_create_time"] = current_timestamp()
- return super().insert(__data, **insert)
- # update and insert will call this method
- @classmethod
- def _normalize_data(cls, data, kwargs):
- normalized = super()._normalize_data(data, kwargs)
- if not normalized:
- return {}
- normalized[cls._meta.combined["f_update_time"]] = current_timestamp()
- for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
- if {f"f_{f_n}_time", f"f_{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
- cls._meta.combined[f"f_{f_n}_time"] in normalized and \
- normalized[cls._meta.combined[f"f_{f_n}_time"]] is not None:
- normalized[cls._meta.combined[f"f_{f_n}_date"]] = timestamp_to_date(
- normalized[cls._meta.combined[f"f_{f_n}_time"]])
- return normalized
|