data_utils.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_arch.abc import StorageTableMetaABC, AddressABC
from fate_arch.common.address import MysqlAddress, HiveAddress
from fate_arch.common.data_utils import default_output_fs_path
from fate_arch.computing import ComputingEngine
from fate_arch.storage import StorageEngine, StorageTableMeta
from fate_flow.entity.types import InputSearchType
from fate_arch import storage


def get_header_schema(header_line, id_delimiter, extend_sid=False):
    """Split a header line into the sid column name and the remaining header."""
    header_source_item = header_line.split(id_delimiter)
    if extend_sid:
        # The file has no id column: keep the full line as the header and
        # use the generated sid column name.
        header = id_delimiter.join(header_source_item).strip()
        sid = get_extend_id_name()
    else:
        # First column is the sid; the remaining columns form the header.
        header = id_delimiter.join(header_source_item[1:]).strip()
        sid = header_source_item[0].strip()
    return {'header': header, 'sid': sid}
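
# Illustrative behaviour (values are hypothetical):
#   get_header_schema("id,x0,x1", ",")                -> {'header': 'x0,x1', 'sid': 'id'}
#   get_header_schema("x0,x1", ",", extend_sid=True)  -> {'header': 'x0,x1', 'sid': 'extend_sid'}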


def get_extend_id_name():
    # Name of the auto-generated sid column used when extend_sid is enabled.
    return "extend_sid"


def get_sid_data_line(values, id_delimiter, fate_uuid, line_index, **kwargs):
    # Generate a sid from the job uuid plus the line index; all values are data.
    return line_extend_uuid(fate_uuid, line_index), list_to_str(values, id_delimiter=id_delimiter)


def line_extend_uuid(fate_uuid, line_index):
    return fate_uuid + str(line_index)


def get_auto_increasing_sid_data_line(values, id_delimiter, line_index, **kwargs):
    # Use the line index itself as the sid.
    return line_index, list_to_str(values, id_delimiter=id_delimiter)


def get_data_line(values, id_delimiter, **kwargs):
    # First value is the sid; the rest are joined into the data string.
    return values[0], list_to_str(values[1:], id_delimiter=id_delimiter)


def list_to_str(input_list, id_delimiter):
    return id_delimiter.join(map(str, input_list))
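
# Illustrative outputs (values are hypothetical):
#   get_data_line(["id_0", 1, 2], ",")                 -> ("id_0", "1,2")
#   get_sid_data_line([1, 2], ",", "abc", 0)           -> ("abc0", "1,2")
#   get_auto_increasing_sid_data_line([1, 2], ",", 7)  -> (7, "1,2")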


def convert_output(
        input_name,
        input_namespace,
        output_name,
        output_namespace,
        computing_engine: ComputingEngine = ComputingEngine.EGGROLL,
        output_storage_address=None,
) -> (StorageTableMetaABC, AddressABC, StorageEngine):
    """Derive the output table address and storage engine for a given input
    table and computing engine."""
    input_table_meta = StorageTableMeta(name=input_name, namespace=input_namespace)
    if not input_table_meta:
        raise RuntimeError(
            f"cannot find table, name: {input_name}, namespace: {input_namespace}"
        )
    # Copy so the caller's address dict is never mutated (default is None
    # rather than a mutable {} default argument).
    address_dict = (output_storage_address or {}).copy()
    if input_table_meta.get_engine() in [StorageEngine.PATH]:
        # PATH tables are not copied; the output points at the same path.
        from fate_arch.storage import PathStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = PathStoreType.PICTURE
        address_dict["path"] = input_table_meta.get_address().path
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.PATH, address_dict=address_dict
        )
        output_table_engine = StorageEngine.PATH
    elif computing_engine == ComputingEngine.STANDALONE:
        from fate_arch.storage import StandaloneStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = StandaloneStoreType.ROLLPAIR_LMDB
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.STANDALONE, address_dict=address_dict
        )
        output_table_engine = StorageEngine.STANDALONE
    elif computing_engine == ComputingEngine.EGGROLL:
        from fate_arch.storage import EggRollStoreType
        address_dict["name"] = output_name
        address_dict["namespace"] = output_namespace
        address_dict["storage_type"] = EggRollStoreType.ROLLPAIR_LMDB
        output_table_address = StorageTableMeta.create_address(
            storage_engine=StorageEngine.EGGROLL, address_dict=address_dict
        )
        output_table_engine = StorageEngine.EGGROLL
    elif computing_engine == ComputingEngine.SPARK:
        if input_table_meta.get_engine() == StorageEngine.HIVE:
            # Reuse the input Hive address; only the table name changes.
            output_table_address = input_table_meta.get_address()
            output_table_address.name = output_name
            output_table_engine = input_table_meta.get_engine()
        elif input_table_meta.get_engine() == StorageEngine.LOCALFS:
            output_table_address = input_table_meta.get_address()
            output_table_address.path = default_output_fs_path(
                name=output_name,
                namespace=output_namespace,
                storage_engine=StorageEngine.LOCALFS
            )
            output_table_engine = input_table_meta.get_engine()
        else:
            # Default Spark output goes to HDFS.
            address_dict["path"] = default_output_fs_path(
                name=output_name,
                namespace=output_namespace,
                prefix=address_dict.get("path_prefix"),
                storage_engine=StorageEngine.HDFS
            )
            output_table_address = StorageTableMeta.create_address(
                storage_engine=StorageEngine.HDFS, address_dict=address_dict
            )
            output_table_engine = StorageEngine.HDFS
    elif computing_engine == ComputingEngine.LINKIS_SPARK:
        output_table_address = input_table_meta.get_address()
        output_table_address.name = output_name
        output_table_engine = input_table_meta.get_engine()
    else:
        raise RuntimeError(f"unsupported computing engine: {computing_engine}")
    return input_table_meta, output_table_address, output_table_engine
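
# A minimal usage sketch (table names are hypothetical; assumes the input
# table is already registered in the meta store):
#   input_meta, out_address, out_engine = convert_output(
#       input_name="breast_hetero_guest",
#       input_namespace="experiment",
#       output_name="breast_hetero_guest_out",
#       output_namespace="experiment",
#       computing_engine=ComputingEngine.EGGROLL,
#   )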


def get_input_data_min_partitions(input_data, role, party_id):
    """Return the smallest partition count across this party's input tables,
    or None if no table metadata can be found."""
    min_partition = None
    if role != 'arbiter':
        for data_type, data_location in input_data[role][party_id].items():
            # data_location has the form "<namespace>.<name>".
            namespace, name = data_location.split('.')[0], data_location.split('.')[1]
            table_meta = storage.StorageTableMeta(name=name, namespace=namespace)
            if table_meta:
                table_partition = table_meta.get_partitions()
                if not min_partition or min_partition > table_partition:
                    min_partition = table_partition
    return min_partition
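
# Illustrative input (names are hypothetical):
#   input_data = {"guest": {"9999": {"train_data": "experiment.breast_hetero_guest"}}}
#   get_input_data_min_partitions(input_data, "guest", "9999")
# returns that table's partition count, or None if its metadata is missing.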


def get_input_search_type(parameters):
    # Decide how to look up the input data based on which keys are present.
    if "name" in parameters and "namespace" in parameters:
        return InputSearchType.TABLE_INFO
    elif "job_id" in parameters and "component_name" in parameters and "data_name" in parameters:
        return InputSearchType.JOB_COMPONENT_OUTPUT
    else:
        return InputSearchType.UNKNOWN
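
# e.g. {"name": "t1", "namespace": "ns"} -> InputSearchType.TABLE_INFO, while
# {"job_id": "...", "component_name": "...", "data_name": "..."} resolves to
# InputSearchType.JOB_COMPONENT_OUTPUT (key values are hypothetical).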


def address_filter(address):
    """Blank out credentials before the address is exposed, e.g. logged or
    returned via an API."""
    if isinstance(address, MysqlAddress):
        address.passwd = None
    if isinstance(address, HiveAddress):
        address.password = None
    return address.__dict__
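
# Illustrative use (constructor fields here are assumptions, not part of this module):
#   addr = MysqlAddress(user="fate", passwd="secret", host="127.0.0.1",
#                       port=3306, db="fate_flow", name="t1")
#   address_filter(addr)  # -> field dict with passwd == None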