123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import traceback
- import pymysql
- from fate_arch.common import log
- from fate_arch.storage import StorageTableBase
- from fate_flow.component_env_utils import feature_utils
- from fate_flow.external.storage.base import Storage, MysqlAddress
- from fate_flow.manager.data_manager import get_component_output_data_schema
# Module-level logger used by MysqlStorage for progress and connection messages.
LOGGER = log.getLogger()


class MysqlStorage(Storage):
    """Write the rows of a storage table into a MySQL table.

    A connection is opened as soon as the object is constructed. The object
    can also be used as a context manager: entering opens a fresh connection
    (closing the previous one) and leaving closes it.
    """

    # Number of rows buffered per REPLACE statement and commit in save().
    BATCH_SIZE = 10000

    def __init__(self, address: dict, storage_table: StorageTableBase):
        """Remember the target address and the source table, then connect.

        Args:
            address: Keyword arguments for ``MysqlAddress``. This class reads
                ``host``, ``user``, ``passwd``, ``port``, ``db`` and ``name``
                (the target table name) from the resulting object.
            storage_table: Source table; this class reads its ``namespace``,
                ``name``, ``meta`` and iterates ``collect()``.
        """
        self.address = MysqlAddress(**address)
        self.storage_table = storage_table
        self._con = None
        self._cur = None
        self._connect()

    def save(self):
        """Copy every row of the storage table into the MySQL table.

        The target table is created (if missing) when the first row is seen,
        using the header derived from that row. Rows are then written with
        parameterized ``REPLACE INTO`` statements in batches of
        ``BATCH_SIZE``, committing after each batch. An empty source table
        creates nothing and writes nothing.
        """
        LOGGER.info(
            f"start save Table({self.storage_table.namespace}, {self.storage_table.name}) "
            f"to Mysql({self.address.db}, {self.address.name})"
        )
        join_delimiter = ","
        replace_sql = None
        pending_rows = []
        count = 0
        for k, v in self.storage_table.collect():
            # The deserialized value is treated as one delimiter-joined string
            # (it is split on the same delimiter below).
            v, extend_header = feature_utils.get_deserialize_value(v, join_delimiter)
            if replace_sql is None:
                _, header_list = self._create_table(extend_header)
                LOGGER.info("craete table success")
                # One %s placeholder per column: the driver escapes the values,
                # so quotes or backslashes in the data cannot break the SQL.
                replace_sql = "REPLACE INTO {}({}) VALUES ({})".format(
                    self.address.name,
                    ", ".join(header_list),
                    ", ".join(["%s"] * len(header_list)),
                )
            # The key is stringified to match how it was previously formatted
            # into the statement text.
            pending_rows.append((str(k), *v.split(join_delimiter)))
            count += 1
            if len(pending_rows) >= self.BATCH_SIZE:
                self._write_batch(replace_sql, pending_rows)
                pending_rows = []
                LOGGER.info(f"save data count:{count}")
        # Flush only what is still buffered. Checking the buffer (rather than
        # the total count) avoids executing an empty statement when the row
        # count is an exact multiple of BATCH_SIZE.
        if pending_rows:
            self._write_batch(replace_sql, pending_rows)
        LOGGER.info(f"save success, count:{count}")

    def _write_batch(self, replace_sql, rows):
        """Execute ``replace_sql`` once per row in ``rows`` and commit."""
        self._cur.executemany(replace_sql, rows)
        self._con.commit()

    def _create_table(self, extend_header):
        """Create the target table if it does not exist yet.

        The first header entry becomes a ``varchar(100) NOT NULL`` primary
        key; every remaining entry becomes a ``varchar(255)`` column.

        Args:
            extend_header: Extra header information passed through to
                ``get_component_output_data_schema`` together with the
                storage table's ``meta``.

        Returns:
            A ``(execute_result, header_list)`` tuple: the return value of the
            cursor's ``execute`` call and the full column-name list.
        """
        header_list = get_component_output_data_schema(self.storage_table.meta, extend_header)
        # feature_sql already ends with a comma (or is empty), which is why
        # the template has no comma before PRIMARY KEY.
        feature_sql = self.get_create_features_sql(header_list[1:])
        id_size = "varchar(100)"
        create_table = (
            "create table if not exists {}({} {} NOT NULL, {} PRIMARY KEY({}))".format(
                self.address.name, header_list[0], id_size, feature_sql, header_list[0]
            )
        )
        LOGGER.info(f"create table {self.address.name}: {create_table}")
        return self._cur.execute(create_table), header_list

    @staticmethod
    def get_create_features_sql(feature_name_list):
        """Return the column-definition fragment for the feature columns.

        Each name yields ``"<name> varchar(255),"``; the fragments are
        concatenated without separators, so a non-empty result always ends
        with a trailing comma and an empty list yields ``""``.
        """
        feature_size = "varchar(255)"
        return "".join(
            "{} {},".format(feature_name, feature_size)
            for feature_name in feature_name_list
        )

    def _create_db_if_not_exists(self):
        """Create the configured database if it does not exist.

        Uses a short-lived connection that is not bound to any database and
        is always closed afterwards, whether or not the statement succeeds.
        """
        connection = pymysql.connect(host=self.address.host,
                                     user=self.address.user,
                                     password=self.address.passwd,
                                     port=self.address.port)
        try:
            with connection.cursor() as cursor:
                cursor.execute("create database if not exists {}".format(self.address.db))
                LOGGER.info('create db {} success'.format(self.address.db))
            connection.commit()
        finally:
            # Close explicitly instead of relying on ``with connection``,
            # whose exit behaviour differs between PyMySQL versions.
            connection.close()

    def _connect(self):
        """Open a connection and cursor to the configured database.

        Any connection already held is closed first, so calling this more
        than once (``__init__`` followed by ``__enter__``) does not leak the
        earlier connection.
        """
        self._close()
        LOGGER.info(f"start connect database {self.address.db}")
        self._con = pymysql.connect(host=self.address.host,
                                    user=self.address.user,
                                    password=self.address.passwd,
                                    port=self.address.port,
                                    database=self.address.db)
        self._cur = self._con.cursor()
        LOGGER.info("connect success!")

    def _close(self):
        """Close the cursor and connection, best effort, and forget them.

        Each handle is closed independently so a failure on the cursor does
        not prevent the connection from being closed. Failures are printed,
        not raised.
        """
        for handle in (self._cur, self._con):
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                traceback.print_exc()
        self._cur = None
        self._con = None

    def _open(self):
        """Return the object handed to the ``with`` statement (this instance)."""
        return self

    def __enter__(self):
        """Open a fresh connection and return this instance."""
        self._connect()
        return self._open()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the cursor and connection; never suppresses exceptions."""
        LOGGER.info("close connect")
        self._close()
|