#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing

import numpy as np

from fate_arch import session
from fate_arch.abc import AddressABC, StorageTableABC, StorageTableMetaABC
from fate_arch.common import EngineType, log
from fate_arch.common.data_utils import default_output_fs_path, default_output_info
from fate_arch.computing import ComputingEngine
from fate_arch.session import Session
from fate_arch.storage import StorageEngine, StorageTableMeta, StorageTableOrigin
from fate_flow.components._base import (
    BaseParam,
    ComponentBase,
    ComponentInputProtocol,
    ComponentMeta,
)
from fate_flow.errors import ParameterError
from fate_flow.entity import MetricMeta
from fate_flow.entity.types import InputSearchType
from fate_flow.manager.data_manager import DataTableTracker, TableStorage, AnonymousGenerator
from fate_flow.operation.job_tracker import Tracker
from fate_flow.utils import data_utils

LOGGER = log.getLogger()
MAX_NUM = 10000

reader_cpn_meta = ComponentMeta("Reader")


@reader_cpn_meta.bind_param
class ReaderParam(BaseParam):
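    """Parameter holder for the Reader component; ``table`` describes the input table to read."""
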
    def __init__(self, table=None):
        self.table = table

    def check(self):
        return True


@reader_cpn_meta.bind_runner.on_guest.on_host
class Reader(ComponentBase):
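    """
    Component runner for guest and host: reads a registered storage table and
    materializes it as this job's output data table, converting the storage
    engine or serialization format when the computing engine requires it.
    """
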
    def __init__(self):
        super(Reader, self).__init__()
        self.parameters = None
        self.job_parameters = None

    def _run(self, cpn_input: ComponentInputProtocol):
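        """
        Resolve the input table, create the output table, copy the data across,
        then record lineage and a data_info metric for the board.
        """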
        self.parameters = cpn_input.parameters
        self.job_parameters = cpn_input.job_parameters
        output_storage_address = self.job_parameters.engines_address[EngineType.STORAGE]
        # only one input table is supported
        table_key = next(iter(self.parameters))
        input_table_namespace, input_table_name = self.get_input_table_info(
            parameters=self.parameters[table_key],
            role=self.tracker.role,
            party_id=self.tracker.party_id,
        )
        (
            output_table_namespace,
            output_table_name,
        ) = default_output_info(
            task_id=self.tracker.task_id,
            task_version=self.tracker.task_version,
            output_type="data",
        )
        (
            input_table_meta,
            output_table_address,
            output_table_engine,
        ) = self.convert_check(
            input_name=input_table_name,
            input_namespace=input_table_namespace,
            output_name=output_table_name,
            output_namespace=output_table_namespace,
            computing_engine=self.job_parameters.computing_engine,
            output_storage_address=output_storage_address,
        )
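        # look up the input table through the global session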
        sess = Session.get_global()
        input_table = sess.get_table(
            name=input_table_meta.get_name(), namespace=input_table_meta.get_namespace()
        )
        # update the real count in the meta info
        input_table.count()
        # check whether table replication across engines is required
        if input_table_meta.get_engine() != output_table_engine:
            LOGGER.info(
                f"the {input_table_meta.get_engine()} engine input table needs to be converted "
                f"to the {output_table_engine} engine to support computing engine {self.job_parameters.computing_engine}"
            )
        else:
            LOGGER.info(
                f"the {input_table_meta.get_engine()} input table needs a format transformation"
            )
        LOGGER.info("reader creates storage session")
        output_table_session = sess.storage(storage_engine=output_table_engine)
        output_table = output_table_session.create_table(
            address=output_table_address,
            name=output_table_name,
            namespace=output_table_namespace,
            partitions=input_table_meta.partitions,
            origin=StorageTableOrigin.READER,
        )
        self.save_table(src_table=input_table, dest_table=output_table)
        # update the real count in the meta info
        output_table_meta = StorageTableMeta(
            name=output_table.name, namespace=output_table.namespace
        )
        # TODO: maybe set output data, and let the executor support passing persistent data
        flow_data_names = cpn_input.flow_feeded_parameters.get("output_data_name")
        self.tracker.log_output_data_info(
            data_name=flow_data_names[0] if flow_data_names else table_key,
            table_namespace=output_table_meta.get_namespace(),
            table_name=output_table_meta.get_name(),
        )
        DataTableTracker.create_table_tracker(
            output_table_meta.get_name(),
            output_table_meta.get_namespace(),
            entity_info={
                "have_parent": True,
                "parent_table_namespace": input_table_namespace,
                "parent_table_name": input_table_name,
                "job_id": self.tracker.job_id,
            },
        )
        table_info, anonymous_info, attribute_info = self.data_info_display(output_table_meta)
        data_info = {
            "table_name": input_table_name,
            "namespace": input_table_namespace,
            "table_info": table_info,
            "anonymous_info": anonymous_info,
            "attribute_info": attribute_info,
            "partitions": output_table_meta.get_partitions(),
            "storage_engine": output_table_meta.get_engine(),
        }
        if input_table_meta.get_engine() in [StorageEngine.PATH]:
            data_info["file_count"] = output_table_meta.get_count()
            data_info["file_path"] = input_table_meta.get_address().path
        else:
            data_info["count"] = output_table_meta.get_count()
        self.tracker.set_metric_meta(
            metric_namespace="reader_namespace",
            metric_name="reader_name",
            metric_meta=MetricMeta(
                name="reader", metric_type="data_info", extra_metas=data_info
            ),
        )

    @staticmethod
    def get_input_table_info(parameters, role, party_id):
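        """
        Resolve the input table's (namespace, name), either directly from the
        parameters or by looking up another job component's output data.
        """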
        search_type = data_utils.get_input_search_type(parameters)
        if search_type is InputSearchType.TABLE_INFO:
            return parameters["namespace"], parameters["name"]
        elif search_type is InputSearchType.JOB_COMPONENT_OUTPUT:
            output_data_infos = Tracker.query_output_data_infos(
                job_id=parameters["job_id"],
                component_name=parameters["component_name"],
                data_name=parameters["data_name"],
                role=role,
                party_id=party_id,
            )
            if not output_data_infos:
                raise Exception("cannot find the input table, please check the parameters")
            namespace, name = (
                output_data_infos[0].f_table_namespace,
                output_data_infos[0].f_table_name,
            )
            LOGGER.info(f"found input table {namespace} {name} by {parameters}")
            return namespace, name
        else:
            raise ParameterError(
                f"cannot find input table info by parameters {parameters}"
            )

    @staticmethod
    def convert_check(
        input_name,
        input_namespace,
        output_name,
        output_namespace,
        computing_engine: ComputingEngine = ComputingEngine.EGGROLL,
        output_storage_address=None,
    ) -> typing.Tuple[StorageTableMetaABC, AddressABC, StorageEngine]:
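        """
        Decide whether the input table must be converted for the computing
        engine; return the input table meta, the output table address, and
        the output storage engine.
        """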
        output_storage_address = output_storage_address or {}  # avoid a shared mutable default
        return data_utils.convert_output(input_name, input_namespace, output_name, output_namespace,
                                         computing_engine, output_storage_address)

    def save_table(self, src_table: StorageTableABC, dest_table: StorageTableABC):
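        """
        Copy ``src_table`` into ``dest_table``: go through the computing
        session when both tables share the same engine and the source is
        serialized, otherwise fall back to a generic table copy.
        """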
        LOGGER.info("start copying table")
        LOGGER.info(
            f"source table name: {src_table.name} namespace: {src_table.namespace} engine: {src_table.engine}"
        )
        LOGGER.info(
            f"destination table name: {dest_table.name} namespace: {dest_table.namespace} engine: {dest_table.engine}"
        )
        if src_table.engine == dest_table.engine and src_table.meta.get_in_serialized():
            self.to_save(src_table, dest_table)
        else:
            TableStorage.copy_table(src_table, dest_table)
            # update the anonymous header
            self.create_anonymous(src_meta=src_table.meta, dest_meta=dest_table.meta)

    def to_save(self, src_table, dest_table):
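        """
        Load the source table into the computing session, save it as this
        job's output data, then sync the schema and anonymous header to the
        destination meta.
        """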
        src_table_meta = src_table.meta
        src_computing_table = session.get_computing_session().load(
            src_table_meta.get_address(),
            schema=src_table_meta.get_schema(),
            partitions=src_table_meta.get_partitions(),
            id_delimiter=src_table_meta.get_id_delimiter(),
            in_serialized=src_table_meta.get_in_serialized(),
        )
        schema = src_table_meta.get_schema()
        self.tracker.job_tracker.save_output_data(
            src_computing_table,
            output_storage_engine=dest_table.engine,
            output_storage_address=dest_table.address.__dict__,
            output_table_namespace=dest_table.namespace,
            output_table_name=dest_table.name,
            schema=schema,
            need_read=False,
        )
        schema = self.update_anonymous(
            computing_table=src_computing_table, schema=schema, src_table_meta=src_table_meta
        )
        LOGGER.info(f"dest schema: {schema}")
        dest_table.meta.update_metas(
            schema=schema,
            part_of_data=src_table_meta.get_part_of_data(),
            count=src_table_meta.get_count(),
            id_delimiter=src_table_meta.get_id_delimiter(),
        )
        LOGGER.info(f"save {dest_table.namespace} {dest_table.name} success")
        return src_computing_table

    def update_anonymous(self, computing_table, schema, src_table_meta):
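        """
        Ensure the schema carries an anonymous header: generate one if it is
        missing, persist it to the source meta, then stamp it with this
        party's role and party_id.
        """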
        if schema.get("meta"):
            if "anonymous_header" not in schema:
                schema.update(AnonymousGenerator.generate_header(computing_table, schema))
                schema = AnonymousGenerator.generate_anonymous_header(schema=schema)
                src_table_meta.update_metas(schema=schema)
            schema = AnonymousGenerator.update_anonymous_header_with_role(
                schema, self.tracker.role, self.tracker.party_id
            )
        return schema

    def create_anonymous(self, src_meta, dest_meta):
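        """
        Generate anonymous headers for both the source and destination metas
        after a cross-engine copy; if the source already has one, only restamp
        the destination header with this party's role and party_id.
        """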
        src_schema = src_meta.get_schema()
        dest_schema = dest_meta.get_schema()
        LOGGER.info(f"src schema: {src_schema}, dest schema {dest_schema}")
        if src_schema.get("meta"):
            if "anonymous_header" not in src_schema:
                LOGGER.info("start to create anonymous")
                dest_computing_table = session.get_computing_session().load(
                    dest_meta.get_address(),
                    schema=dest_meta.get_schema(),
                    partitions=dest_meta.get_partitions(),
                    id_delimiter=dest_meta.get_id_delimiter(),
                    in_serialized=dest_meta.get_in_serialized(),
                )
                src_schema.update(AnonymousGenerator.generate_header(dest_computing_table, src_schema))
                dest_schema.update(AnonymousGenerator.generate_header(dest_computing_table, dest_schema))
                src_schema = AnonymousGenerator.generate_anonymous_header(schema=src_schema)
                dest_schema = AnonymousGenerator.generate_anonymous_header(schema=dest_schema)
                dest_schema = AnonymousGenerator.update_anonymous_header_with_role(
                    dest_schema, self.tracker.role, self.tracker.party_id
                )
                LOGGER.info(f"update src schema {src_schema} and dest schema {dest_schema}")
                src_meta.update_metas(schema=src_schema)
                dest_meta.update_metas(schema=dest_schema)
            else:
                dest_schema = AnonymousGenerator.update_anonymous_header_with_role(
                    dest_schema, self.tracker.role, self.tracker.party_id
                )
                LOGGER.info(f"update dest schema {dest_schema}")
                dest_meta.update_metas(schema=dest_schema)

    @staticmethod
    def data_info_display(output_table_meta):
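        """
        Build preview structures for the board from the output table meta:
        per-column sample values, a header-to-anonymous-name mapping, and
        per-column attributes (feature / label / match_id).
        """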
        headers = output_table_meta.get_schema().get("header")
        schema = output_table_meta.get_schema()
        table_info = {}
        anonymous_info = {}
        attribute_info = {}
        try:
            if schema and headers:
                if schema.get("original_index_info"):
                    data_list = [AnonymousGenerator.reconstruct_header(schema)]
                else:
                    if isinstance(headers, str):
                        data_list = [headers.split(",")]
                    else:
                        data_list = [[schema.get("label_name")] if schema.get("label_name") else []]
                        data_list[0].extend(headers)
                LOGGER.info(f"data info header: {data_list[0]}")
                for data in output_table_meta.get_part_of_data():
                    delimiter = schema.get("meta", {}).get("delimiter") or output_table_meta.id_delimiter
                    data_list.append(data[1].split(delimiter))
                data = np.array(data_list)
                Tdata = data.transpose()
                for data in Tdata:
                    table_info[data[0]] = ",".join(list(set(data[1:]))[:5])
            if schema and schema.get("anonymous_header"):
                anonymous_info = dict(zip(schema.get("header"), schema.get("anonymous_header")))
                attribute_info = dict(zip(schema.get("header"), ["feature"] * len(schema.get("header"))))
                if schema.get("label_name"):
                    anonymous_info[schema.get("label_name")] = schema.get("anonymous_label")
                    attribute_info[schema.get("label_name")] = "label"
                if schema.get("meta", {}).get("id_list"):
                    for id_name in schema.get("meta", {}).get("id_list"):
                        if id_name in attribute_info:
                            attribute_info[id_name] = "match_id"
        except Exception as e:
            LOGGER.exception(e)
        return table_info, anonymous_info, attribute_info