# _table.py (5.2 KB)

#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from typing import Iterator

from fate_arch.storage import StorageEngine, MySQLStoreType
from fate_arch.storage import StorageTableBase


  18. class StorageTable(StorageTableBase):
  19. def __init__(
  20. self,
  21. cur,
  22. con,
  23. address=None,
  24. name: str = None,
  25. namespace: str = None,
  26. partitions: int = 1,
  27. store_type: MySQLStoreType = MySQLStoreType.InnoDB,
  28. options=None,
  29. ):
  30. super(StorageTable, self).__init__(
  31. name=name,
  32. namespace=namespace,
  33. address=address,
  34. partitions=partitions,
  35. options=options,
  36. engine=StorageEngine.MYSQL,
  37. store_type=store_type,
  38. )
  39. self._cur = cur
  40. self._con = con
  41. def check_address(self):
  42. schema = self.meta.get_schema()
  43. if schema:
  44. if schema.get("sid") and schema.get("header"):
  45. sql = "SELECT {},{} FROM {}".format(
  46. schema.get("sid"), schema.get("header"), self._address.name
  47. )
  48. else:
  49. sql = "SELECT {} FROM {}".format(
  50. schema.get("sid"), self._address.name
  51. )
  52. feature_data = self.execute(sql)
  53. for feature in feature_data:
  54. if feature:
  55. break
  56. return True
  57. @staticmethod
  58. def get_meta_header(feature_name_list):
  59. create_features = ""
  60. feature_list = []
  61. feature_size = "varchar(255)"
  62. for feature_name in feature_name_list:
  63. create_features += "{} {},".format(feature_name, feature_size)
  64. feature_list.append(feature_name)
  65. return create_features, feature_list
  66. def _count(self):
  67. sql = "select count(*) from {}".format(self._address.name)
  68. try:
  69. self._cur.execute(sql)
  70. # self.con.commit()
  71. ret = self._cur.fetchall()
  72. count = ret[0][0]
  73. except BaseException:
  74. count = 0
  75. return count
  76. def _collect(self, **kwargs) -> list:
  77. id_name, feature_name_list, _ = self._get_id_feature_name()
  78. id_feature_name = [id_name]
  79. id_feature_name.extend(feature_name_list)
  80. sql = "select {} from {}".format(",".join(id_feature_name), self._address.name)
  81. data = self.execute(sql)
  82. for line in data:
  83. feature_list = [str(feature) for feature in list(line[1:])]
  84. yield line[0], self.meta.get_id_delimiter().join(feature_list)
  85. def _put_all(self, kv_list, **kwargs):
  86. id_name, feature_name_list, id_delimiter = self._get_id_feature_name()
  87. feature_sql, feature_list = StorageTable.get_meta_header(feature_name_list)
  88. id_size = "varchar(100)"
  89. create_table = (
  90. "create table if not exists {}({} {} NOT NULL, {} PRIMARY KEY({}))".format(
  91. self._address.name, id_name, id_size, feature_sql, id_name
  92. )
  93. )
  94. self._cur.execute(create_table)
  95. sql = "REPLACE INTO {}({}, {}) VALUES".format(
  96. self._address.name, id_name, ",".join(feature_list)
  97. )
  98. for kv in kv_list:
  99. sql += '("{}", "{}"),'.format(kv[0], '", "'.join(kv[1].split(id_delimiter)))
  100. sql = ",".join(sql.split(",")[:-1]) + ";"
  101. self._cur.execute(sql)
  102. self._con.commit()
  103. def _destroy(self):
  104. sql = "drop table {}".format(self._address.name)
  105. self._cur.execute(sql)
  106. self._con.commit()
  107. def _save_as(self, address, name, namespace, partitions=None, **kwargs):
  108. sql = "create table {}.{} select * from {};".format(namespace, name, self._address.name)
  109. self._cur.execute(sql)
  110. self._con.commit()
  111. def execute(self, sql, select=True):
  112. self._cur.execute(sql)
  113. if select:
  114. while True:
  115. result = self._cur.fetchone()
  116. if result:
  117. yield result
  118. else:
  119. break
  120. else:
  121. result = self._cur.fetchall()
  122. return result
  123. def _get_id_feature_name(self):
  124. id = self.meta.get_schema().get("sid", "id")
  125. header = self.meta.get_schema().get("header", [])
  126. id_delimiter = self.meta.get_id_delimiter()
  127. if not header:
  128. feature_list = []
  129. elif isinstance(header, str):
  130. feature_list = header.split(id_delimiter)
  131. elif isinstance(header, list):
  132. feature_list = header
  133. else:
  134. feature_list = [header]
  135. if self.meta.get_extend_sid():
  136. id = feature_list[0]
  137. if len(feature_list) > 1:
  138. feature_list = feature_list[1:]
  139. return id, feature_list, id_delimiter