_table.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import uuid

from fate_arch.common import hive_utils
from fate_arch.common.file_utils import get_project_base_directory
from fate_arch.storage import StorageEngine, HiveStoreType
from fate_arch.storage import StorageTableBase


class StorageTable(StorageTableBase):
    def __init__(
        self,
        cur,
        con,
        address=None,
        name: str = None,
        namespace: str = None,
        partitions: int = 1,
        storage_type: HiveStoreType = HiveStoreType.DEFAULT,
        options=None,
    ):
        super(StorageTable, self).__init__(
            name=name,
            namespace=namespace,
            address=address,
            partitions=partitions,
            options=options,
            engine=StorageEngine.HIVE,
            store_type=storage_type,
        )
        self._cur = cur
        self._con = con

    def execute(self, sql, select=True):
        self._cur.execute(sql)
        if select:
            while True:
                result = self._cur.fetchone()
                if result:
                    yield result
                else:
                    break
        else:
            # note: the yield above makes execute() a generator, so this value
            # is only surfaced as StopIteration.value, not an ordinary return
            result = self._cur.fetchall()
            return result

    def _count(self, **kwargs):
        sql = 'select count(*) from {}'.format(self._address.name)
        try:
            self._cur.execute(sql)
            self._con.commit()
            ret = self._cur.fetchall()
            count = ret[0][0]
        except BaseException:
            # treat any failure (e.g. the table does not exist yet) as empty
            count = 0
        return count

    def _collect(self, **kwargs) -> list:
        sql = "select * from {}".format(self._address.name)
        data = self.execute(sql)
        for line in data:
            yield hive_utils.deserialize_line(line)

    def _read(self) -> list:
        id_name, feature_name_list, _ = self._get_id_feature_name()
        id_feature_name = [id_name]
        id_feature_name.extend(feature_name_list)
        sql = "select {} from {}".format(",".join(id_feature_name), self._address.name)
        data = self.execute(sql)
        for line in data:
            yield hive_utils.read_line(line)

    def _put_all(self, kv_list, **kwargs):
        id_name, feature_name_list, id_delimiter = self.get_id_feature_name()
        create_table = "create table if not exists {}(k varchar(128) NOT NULL, v string) " \
                       "row format delimited fields terminated by '{}'".format(self._address.name, id_delimiter)
        self._cur.execute(create_table)
        # serialize the k/v pairs to a temporary local file, then bulk-load it into the table
        temp_path = os.path.join(get_project_base_directory(), 'temp_data', uuid.uuid1().hex)
        os.makedirs(os.path.dirname(temp_path), exist_ok=True)
        with open(temp_path, 'w') as f:
            for k, v in kv_list:
                f.write(hive_utils.serialize_line(k, v))
        sql = "load data local inpath '{}' into table {}".format(temp_path, self._address.name)
        self._cur.execute(sql)
        self._con.commit()
        os.remove(temp_path)

    def get_id_feature_name(self):
        id = self.meta.get_schema().get('sid', 'id')
        header = self.meta.get_schema().get('header')
        id_delimiter = self.meta.get_id_delimiter()
        if header:
            if isinstance(header, str):
                feature_list = header.split(id_delimiter)
            elif isinstance(header, list):
                feature_list = header
            else:
                feature_list = [header]
        else:
            raise Exception("hive table need data header")
        return id, feature_list, id_delimiter

    def _destroy(self):
        sql = "drop table {}".format(self._name)
        return self.execute(sql)

    def _save_as(self, address, name, namespace, partitions=None, **kwargs):
        sql = "create table {}.{} like {}.{};".format(namespace, name, self._namespace, self._name)
        return self.execute(sql)

    def check_address(self):
        schema = self.meta.get_schema()
        if schema:
            sql = 'SELECT {},{} FROM {}'.format(schema.get('sid'), schema.get('header'), self._address.name)
            feature_data = self.execute(sql)
            for feature in feature_data:
                if feature:
                    return True
        return False

    @staticmethod
    def get_meta_header(feature_name_list):
        create_features = ''
        feature_list = []
        feature_size = "varchar(255)"
        for feature_name in feature_name_list:
            create_features += '{} {},'.format(feature_name, feature_size)
            feature_list.append(feature_name)
        return create_features, feature_list

    def _get_id_feature_name(self):
        id = self.meta.get_schema().get("sid", "id")
        header = self.meta.get_schema().get("header")
        id_delimiter = self.meta.get_id_delimiter()
        if header:
            if isinstance(header, str):
                feature_list = header.split(id_delimiter)
            elif isinstance(header, list):
                feature_list = header
            else:
                feature_list = [header]
        else:
            raise Exception("hive table need data header")
        return id, feature_list, id_delimiter
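

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, hypothetical example of wiring
# StorageTable to a pyhive connection. The host, database, table name, and
# namespace below are assumptions; in FATE the cursor, connection, and address
# are normally supplied by the storage session layer, not constructed by hand.
# ---------------------------------------------------------------------------
# from pyhive import hive
#
# con = hive.Connection(host="127.0.0.1", port=10000, database="fate")
# cur = con.cursor()
# table = StorageTable(
#     cur=cur,
#     con=con,
#     name="breast_hetero_guest",   # hypothetical table name
#     namespace="experiment",       # hypothetical namespace
#     partitions=1,
# )
# print(table.count())  # public wrapper around _count() in StorageTableBase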