_session.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import traceback
  17. from impala.dbapi import connect
  18. from fate_arch.common.address import HiveAddress
  19. from fate_arch.storage import StorageSessionBase, StorageEngine, HiveStoreType
  20. from fate_arch.abc import AddressABC
  class StorageSession(StorageSessionBase):
      """Hive-backed storage session.

      Opens Hive connections through ``impala.dbapi.connect`` and caches one
      ``(connection, cursor)`` pair per host/port/database, so every table in
      the same database reuses a single connection. Cached handles are closed
      by ``stop`` (``kill`` is an alias).
      """

      def __init__(self, session_id, options=None):
          # ``options`` is accepted for interface compatibility; it is not used here.
          super(StorageSession, self).__init__(session_id=session_id, engine=StorageEngine.HIVE)
          # HiveAddress cache key (host/port/database only) -> (connection, cursor).
          self._db_con = {}

      def table(self, name, namespace, address: AddressABC, partitions,
                storage_type: HiveStoreType = HiveStoreType.DEFAULT, options=None, **kwargs):
          """Return a Hive ``StorageTable`` bound to a cached connection.

          Args:
              name: table name, passed through to ``StorageTable``.
              namespace: table namespace, passed through to ``StorageTable``.
              address: must be a ``HiveAddress``. Its host/port/database select
                  the cached connection; its credentials are only used when a
                  new connection has to be opened.
              partitions: passed through to ``StorageTable``.
              storage_type: passed through to ``StorageTable``.
              options: passed through to ``StorageTable``.
              **kwargs: accepted and ignored.

          Returns:
              A ``StorageTable`` sharing this session's cursor/connection for
              the address's database.

          Raises:
              NotImplementedError: if ``address`` is not a ``HiveAddress``.
          """
          if not isinstance(address, HiveAddress):
              raise NotImplementedError(f"address type {type(address)} not supported with hive storage")

          from fate_arch.storage.hive._table import StorageTable

          # Cache key keeps only host/port/database; credentials and table name
          # are blanked so all tables of one database share a connection.
          # NOTE(review): a later caller with different credentials for the same
          # host/port/database reuses the first caller's connection - confirm
          # this is intended. Also assumes HiveAddress is hashable (dict key).
          address_key = HiveAddress(
              host=address.host,
              username=None,
              port=address.port,
              database=address.database,
              auth_mechanism=None,
              password=None,
              name=None)

          if address_key in self._db_con:
              con, cur = self._db_con[address_key]
          else:
              # The database must exist before we connect with ``database=`` set.
              self._create_db_if_not_exists(address)
              con = connect(host=address.host,
                            port=address.port,
                            database=address.database,
                            auth_mechanism=address.auth_mechanism,
                            password=address.password,
                            user=address.username
                            )
              cur = con.cursor()
              self._db_con[address_key] = (con, cur)

          return StorageTable(cur=cur, con=con, address=address, name=name, namespace=namespace,
                              storage_type=storage_type, partitions=partitions, options=options)

      def cleanup(self, name, namespace):
          """No-op: this session has nothing to clean up per table."""
          pass

      def stop(self):
          """Close every cached cursor and connection, best-effort.

          Each handle is closed independently, so a failure on one does not
          leave the remaining connections open. Failures are printed as
          tracebacks and otherwise ignored. The cache is emptied afterwards, so
          a repeated call does not close the same handles twice, and a later
          ``table()`` opens a fresh connection instead of reusing a closed one.
          """
          for con, cur in self._db_con.values():
              # Cursor first, then its connection (same order as before).
              for handle in (cur, con):
                  try:
                      handle.close()
                  except Exception:
                      traceback.print_exc()
          self._db_con.clear()

      def kill(self):
          """Alias for ``stop``."""
          return self.stop()

      def _create_db_if_not_exists(self, address):
          """Run ``create database if not exists`` for ``address.database``.

          Uses a short-lived connection and cursor. Cleanup relies on their
          context-manager support in impyla.
          """
          # No ``database=`` here: the target database may not exist yet.
          connection = connect(host=address.host,
                               port=address.port,
                               user=address.username,
                               auth_mechanism=address.auth_mechanism,
                               password=address.password
                               )
          with connection:
              with connection.cursor() as cursor:
                  # NOTE(review): the database name is interpolated verbatim into
                  # the SQL; this assumes it comes from trusted configuration and
                  # never from user input - confirm.
                  cursor.execute("create database if not exists {}".format(address.database))
                  print('create db {} success'.format(address.database))
                  connection.commit()