_csession.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from eggroll.core.session import session_init
  17. from eggroll.roll_pair.roll_pair import runtime_init
  18. from fate_arch.abc import AddressABC, CSessionABC
  19. from fate_arch.common.base_utils import fate_uuid
  20. from fate_arch.common.log import getLogger
  21. from fate_arch.common.profile import computing_profile
  22. from fate_arch.computing.eggroll import Table
  # Module-level logger from fate_arch's logging helper.
  LOGGER = getLogger()


  class CSession(CSessionABC):
      """Eggroll-backed computing session.

      Wraps an eggroll session (``session_init``) and the roll-pair context built
      on it (``runtime_init``). Exposes table loading, parallelizing, cleanup and
      teardown through the ``CSessionABC`` interface.
      """

      def __init__(self, session_id, options: dict = None):
          """Create the eggroll session and its roll-pair context.

          Args:
              session_id: id handed to ``session_init``. The id stored on this
                  object is the one reported back by the created session.
              options: eggroll session options. The caller's dict is not
                  modified; a copy is taken before defaults are filled in:
                  ``eggroll.session.deploy.mode`` -> ``"cluster"`` and
                  ``eggroll.rollpair.inmemory_output`` -> ``True``.
          """
          # Copy so the defaults below never leak back into the caller's dict.
          options = {} if options is None else dict(options)
          options.setdefault("eggroll.session.deploy.mode", "cluster")
          options.setdefault("eggroll.rollpair.inmemory_output", True)
          self._rp_session = session_init(session_id=session_id, options=options)
          self._rpc = runtime_init(session=self._rp_session)
          self._session_id = self._rp_session.get_session_id()

      def get_rpc(self):
          """Return the roll-pair context created by ``runtime_init``."""
          return self._rpc

      @property
      def session_id(self):
          """Session id as reported by the underlying eggroll session."""
          return self._session_id

      @computing_profile
      def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs):
          """Load the data referenced by *address* into this session.

          Args:
              address: an ``EggRollAddress`` (loaded through the roll-pair
                  context) or a ``PathAddress`` (wrapped as ``LocalData``).
              partitions: partition count used for the load and for the
                  in-memory copy.
              schema: schema attached to the returned ``Table``.
              **kwargs: ``option`` holds extra eggroll load options; it is
                  copied, not modified, and ``None`` is treated as empty.
                  ``store_type`` is the eggroll store type and defaults to
                  ``EggRollStoreType.ROLLPAIR_LMDB``.

          Returns:
              A ``Table`` for eggroll addresses, a ``LocalData`` for path
              addresses.

          Raises:
              RuntimeError: the eggroll load returned nothing or a table with
                  zero partitions.
              NotImplementedError: *address* is of any other type.
          """
          from fate_arch.common.address import EggRollAddress
          from fate_arch.storage import EggRollStoreType

          if isinstance(address, EggRollAddress):
              # Copy: the keys set below must not be written into the caller's dict.
              options = dict(kwargs.get("option") or {})
              options["total_partitions"] = partitions
              options["store_type"] = kwargs.get(
                  "store_type", EggRollStoreType.ROLLPAIR_LMDB
              )
              # create_if_missing is forced off; a missing/empty result is
              # rejected just below.
              options["create_if_missing"] = False
              rp = self._rpc.load(
                  namespace=address.namespace, name=address.name, options=options
              )
              if rp is None or rp.get_partitions() == 0:
                  raise RuntimeError(
                      f"no exists: {address.name}, {address.namespace}"
                  )
              # Tables not already stored in memory are copied into an in-memory
              # table under this session's namespace with a unique name.
              if options["store_type"] != EggRollStoreType.ROLLPAIR_IN_MEMORY:
                  rp = rp.save_as(
                      name=f"{address.name}_{fate_uuid()}",
                      namespace=self.session_id,
                      partition=partitions,
                      options={"store_type": EggRollStoreType.ROLLPAIR_IN_MEMORY},
                  )
              table = Table(rp=rp)
              table.schema = schema
              return table

          from fate_arch.common.address import PathAddress

          if isinstance(address, PathAddress):
              from fate_arch.computing.non_distributed import LocalData
              from fate_arch.computing import ComputingEngine

              return LocalData(address.path, engine=ComputingEngine.EGGROLL)

          raise NotImplementedError(
              f"address type {type(address)} not supported with eggroll backend"
          )

      @computing_profile
      def parallelize(self, data, partition: int, include_key: bool, **kwargs) -> Table:
          """Distribute local *data* into a new eggroll-backed ``Table``.

          Args:
              data: local data forwarded to the roll-pair context's ``parallelize``.
              partition: forwarded as the ``total_partitions`` option.
              include_key: forwarded as the ``include_key`` option.
              **kwargs: accepted for interface compatibility; unused.
          """
          options = {"total_partitions": partition, "include_key": include_key}
          rp = self._rpc.parallelize(data=data, options=options)
          return Table(rp)

      def cleanup(self, name, namespace):
          """Forward a cleanup request for *name* in *namespace* to the roll-pair context."""
          self._rpc.cleanup(name=name, namespace=namespace)

      def stop(self):
          """Stop the underlying eggroll session and return its result."""
          return self._rp_session.stop()

      def kill(self):
          """Kill the underlying eggroll session and return its result."""
          return self._rp_session.kill()

      def destroy(self):
          """Best-effort teardown: clean this session's tables, then stop it.

          A failure while cleaning tables is logged and ignored. If ``stop``
          raises, the failure is logged and the session is killed instead;
          errors raised by ``kill`` propagate to the caller.
          """
          try:
              LOGGER.info(f"clean table namespace {self.session_id}")
              self.cleanup(namespace=self.session_id, name="*")
          except Exception:
              LOGGER.warning(f"no found table namespace {self.session_id}")
          try:
              self.stop()
          except Exception as e:
              # Put the exception text in the message itself. Passing it as an
              # extra positional arg with no %-placeholder makes standard logging
              # fail while formatting (msg % args) and drops the detail.
              LOGGER.warning(
                  f"stop storage session {self.session_id} failed, try to kill: {e}"
              )
              self.kill()