mysql.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import traceback
  17. import pymysql
  18. from fate_arch.common import log
  19. from fate_arch.storage import StorageTableBase
  20. from fate_flow.component_env_utils import feature_utils
  21. from fate_flow.external.storage.base import Storage, MysqlAddress
  22. from fate_flow.manager.data_manager import get_component_output_data_schema
# Module-level logger obtained from fate_arch's log helper; MysqlStorage
# uses it for all progress messages.
LOGGER = log.getLogger()
  24. class MysqlStorage(Storage):
  25. def __init__(self, address: dict, storage_table: StorageTableBase):
  26. self.address = MysqlAddress(**address)
  27. self.storage_table = storage_table
  28. self._con = None
  29. self._cur = None
  30. self._connect()
  31. def save(self):
  32. create = False
  33. sql = None
  34. max = 10000
  35. count = 0
  36. LOGGER.info(f"start save Table({self.storage_table.namespace}, {self.storage_table.name}) to Mysql({self.address.db}, {self.address.name})")
  37. join_delimiter = ","
  38. for k, v in self.storage_table.collect():
  39. v, extend_header = feature_utils.get_deserialize_value(v, join_delimiter)
  40. if not create:
  41. _, header_list = self._create_table(extend_header)
  42. LOGGER.info("craete table success")
  43. create = True
  44. if not sql:
  45. sql = "REPLACE INTO {}({}, {}) VALUES".format(
  46. self.address.name, header_list[0], ",".join(header_list[1:])
  47. )
  48. sql += '("{}", "{}"),'.format(k, '", "'.join(v.split(join_delimiter)))
  49. count += 1
  50. if not count % max:
  51. sql = ",".join(sql.split(",")[:-1]) + ";"
  52. self._cur.execute(sql)
  53. self._con.commit()
  54. sql = None
  55. LOGGER.info(f"save data count:{count}")
  56. if count > 0:
  57. sql = ",".join(sql.split(",")[:-1]) + ";"
  58. self._cur.execute(sql)
  59. self._con.commit()
  60. LOGGER.info(f"save success, count:{count}")
  61. def _create_table(self, extend_header):
  62. header_list = get_component_output_data_schema(self.storage_table.meta, extend_header)
  63. feature_sql = self.get_create_features_sql(header_list[1:])
  64. id_size = "varchar(100)"
  65. create_table = (
  66. "create table if not exists {}({} {} NOT NULL, {} PRIMARY KEY({}))".format(
  67. self.address.name, header_list[0], id_size, feature_sql, header_list[0]
  68. )
  69. )
  70. LOGGER.info(f"create table {self.address.name}: {create_table}")
  71. return self._cur.execute(create_table), header_list
  72. @staticmethod
  73. def get_create_features_sql(feature_name_list):
  74. create_features = ""
  75. feature_list = []
  76. feature_size = "varchar(255)"
  77. for feature_name in feature_name_list:
  78. create_features += "{} {},".format(feature_name, feature_size)
  79. feature_list.append(feature_name)
  80. return create_features
  81. def _create_db_if_not_exists(self):
  82. connection = pymysql.connect(host=self.address.host,
  83. user=self.address.user,
  84. password=self.address.passwd,
  85. port=self.address.port)
  86. with connection:
  87. with connection.cursor() as cursor:
  88. cursor.execute("create database if not exists {}".format(self.address.db))
  89. print('create db {} success'.format(self.address.db))
  90. connection.commit()
  91. def _connect(self):
  92. LOGGER.info(f"start connect database {self.address.db}")
  93. self._con = pymysql.connect(host=self.address.host,
  94. user=self.address.user,
  95. passwd=self.address.passwd,
  96. port=self.address.port,
  97. db=self.address.db)
  98. self._cur = self._con.cursor()
  99. LOGGER.info(f"connect success!")
  100. def _open(self):
  101. return self
  102. def __enter__(self):
  103. self._connect()
  104. return self._open()
  105. def __exit__(self, exc_type, exc_val, exc_tb):
  106. try:
  107. LOGGER.info("close connect")
  108. self._cur.close()
  109. self._con.close()
  110. except Exception as e:
  111. traceback.print_exc()