_csession.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #
  2. # Copyright 2019 The Eggroll Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from collections import Iterable
  17. from fate_arch._standalone import Session
  18. from fate_arch.abc import AddressABC, CSessionABC
  19. from fate_arch.common.base_utils import fate_uuid
  20. from fate_arch.common.log import getLogger
  21. from fate_arch.computing.standalone._table import Table
  22. LOGGER = getLogger()
  class CSession(CSessionABC):
      """Computing session for the standalone backend.

      Adapts a ``fate_arch._standalone.Session`` to the ``CSessionABC``
      interface and wraps the raw tables it returns in ``Table``.
      """

      def __init__(self, session_id: str, options=None):
          """Create the underlying standalone session.

          Args:
              session_id: Identifier forwarded to ``Session``.
              options: Optional mapping. Only its ``"task_cores"`` entry is
                  read; the value is passed to ``Session`` as ``max_workers``.
                  When ``options`` is ``None`` or has no such entry,
                  ``max_workers`` is ``None``.
          """
          # Bind max_workers on every path: it used to be assigned only when
          # options was given, so CSession(session_id) raised UnboundLocalError.
          max_workers = None if options is None else options.get("task_cores", None)
          self._session = Session(session_id, max_workers=max_workers)

      def get_standalone_session(self):
          """Return the wrapped standalone ``Session`` object."""
          return self._session

      @property
      def session_id(self):
          """Session id as reported by the wrapped session."""
          return self._session.session_id

      def load(self, address: AddressABC, partitions: int, schema: dict, **kwargs):
          """Load the data referenced by ``address``.

          Args:
              address: Location of the data. ``StandaloneAddress`` and
                  ``PathAddress`` are handled.
              partitions: Partition count used when a standalone table has to
                  be copied (see below); ignored for ``PathAddress``.
              schema: Assigned to the returned ``Table``; ignored for
                  ``PathAddress``.
              **kwargs: Accepted but not used here.

          Returns:
              A ``Table`` for a ``StandaloneAddress``, or a ``LocalData`` for a
              ``PathAddress``.

          Raises:
              NotImplementedError: If ``address`` is of any other type.
          """
          # Imports are kept local, as in the original, rather than at module
          # level.
          from fate_arch.common.address import StandaloneAddress
          from fate_arch.storage import StandaloneStoreType

          if isinstance(address, StandaloneAddress):
              raw_table = self._session.load(address.name, address.namespace)
              if address.storage_type != StandaloneStoreType.ROLLPAIR_IN_MEMORY:
                  # Anything that is not already an in-memory table is copied
                  # under a unique name with the requested partition count;
                  # the copy is flagged need_cleanup=True.
                  raw_table = raw_table.save_as(
                      name=f"{address.name}_{fate_uuid()}",
                      namespace=address.namespace,
                      partition=partitions,
                      need_cleanup=True,
                  )
              table = Table(raw_table)
              table.schema = schema
              return table

          from fate_arch.common.address import PathAddress

          if isinstance(address, PathAddress):
              from fate_arch.computing.non_distributed import LocalData
              from fate_arch.computing import ComputingEngine

              return LocalData(address.path, engine=ComputingEngine.STANDALONE)

          raise NotImplementedError(
              f"address type {type(address)} not supported with standalone backend"
          )

      def parallelize(self, data: Iterable, partition: int, include_key: bool, **kwargs):
          """Build a ``Table`` from ``data`` via the wrapped session.

          All arguments, including ``**kwargs``, are forwarded unchanged to
          ``Session.parallelize``; its result is wrapped in ``Table``.
          """
          table = self._session.parallelize(
              data=data, partition=partition, include_key=include_key, **kwargs
          )
          return Table(table)

      def cleanup(self, name, namespace):
          """Delegate to ``Session.cleanup`` with the given name and namespace."""
          return self._session.cleanup(name=name, namespace=namespace)

      def stop(self):
          """Delegate to ``Session.stop`` and return its result."""
          return self._session.stop()

      def kill(self):
          """Delegate to ``Session.kill`` and return its result."""
          return self._session.kill()

      def destroy(self):
          """Tear the session down on a best-effort basis.

          First calls ``cleanup`` with ``name="*"`` on the namespace equal to
          this session's id; any failure is logged and ignored. Then calls
          ``stop``; if that raises, the failure is logged and ``kill`` is
          called instead. An exception raised by ``kill`` propagates.
          """
          try:
              LOGGER.info(f"clean table namespace {self.session_id}")
              self.cleanup(namespace=self.session_id, name="*")
          except Exception:
              # Deliberately non-fatal: teardown continues even if cleanup
              # fails.
              LOGGER.warning(f"no found table namespace {self.session_id}")

          try:
              self.stop()
          except Exception as e:
              # Pass the exception through %s placeholders. The previous call
              # supplied it as a format argument to a message without any
              # placeholder, which breaks record formatting.
              LOGGER.warning(
                  "stop storage session %s failed, try to kill: %s",
                  self.session_id,
                  e,
              )
              self.kill()