data_manager.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import copy
  17. import datetime
  18. import json
  19. import operator
  20. import os
  21. import tarfile
  22. import uuid
  23. from flask import send_file
  24. from fate_arch.abc import StorageTableABC
  25. from fate_arch.common.base_utils import fate_uuid
  26. from fate_arch.session import Session
  27. from fate_flow.component_env_utils import feature_utils, env_utils
  28. from fate_flow.settings import stat_logger
  29. from fate_flow.db.db_models import DB, TrackingMetric, DataTableTracking
  30. from fate_flow.utils import data_utils
  31. from fate_flow.utils.base_utils import get_fate_flow_directory
  32. from fate_flow.utils.data_utils import get_header_schema, line_extend_uuid
  33. class SchemaMetaParam:
  34. def __init__(self,
  35. delimiter=",",
  36. input_format="dense",
  37. tag_with_value=False,
  38. tag_value_delimiter=":",
  39. with_match_id=False,
  40. id_list=None,
  41. id_range=0,
  42. exclusive_data_type=None,
  43. data_type="float64",
  44. with_label=False,
  45. label_name="y",
  46. label_type="int"):
  47. self.input_format = input_format
  48. self.delimiter = delimiter
  49. self.tag_with_value = tag_with_value
  50. self.tag_value_delimiter = tag_value_delimiter
  51. self.with_match_id = with_match_id
  52. self.id_list = id_list
  53. self.id_range = id_range
  54. self.exclusive_data_type = exclusive_data_type
  55. self.data_type = data_type
  56. self.with_label = with_label
  57. self.label_name = label_name
  58. self.label_type = label_type
  59. self.adapter_param()
  60. def to_dict(self):
  61. d = {}
  62. for k, v in self.__dict__.items():
  63. if v is None:
  64. continue
  65. d[k] = v
  66. return d
  67. def adapter_param(self):
  68. if not self.with_label:
  69. self.label_name = None
  70. self.label_type = None
  71. class AnonymousGenerator(object):
  72. @staticmethod
  73. def update_anonymous_header_with_role(schema, role, party_id):
  74. obj = env_utils.get_class_object("anonymous_generator")
  75. return obj.update_anonymous_header_with_role(schema, role, party_id)
  76. @staticmethod
  77. def generate_anonymous_header(schema):
  78. obj = env_utils.get_class_object("anonymous_generator")()
  79. return obj.generate_anonymous_header(schema)
  80. @staticmethod
  81. def migrate_schema_anonymous(anonymous_schema, role, party_id, migrate_mapping):
  82. obj = env_utils.get_class_object("anonymous_generator")(role, party_id, migrate_mapping)
  83. return obj.migrate_schema_anonymous(anonymous_schema)
  84. @staticmethod
  85. def generate_header(computing_table, schema):
  86. obj = env_utils.get_class_object("data_format")
  87. return obj.generate_header(computing_table, schema)
  88. @staticmethod
  89. def reconstruct_header(schema):
  90. obj = env_utils.get_class_object("data_format")
  91. return obj.reconstruct_header(schema)
  92. @staticmethod
  93. def recover_schema(schema):
  94. obj = env_utils.get_class_object("data_format")
  95. return obj.recover_schema(schema)
  96. class DataTableTracker(object):
  97. @classmethod
  98. @DB.connection_context()
  99. def create_table_tracker(cls, table_name, table_namespace, entity_info):
  100. tracker = DataTableTracking()
  101. tracker.f_table_name = table_name
  102. tracker.f_table_namespace = table_namespace
  103. for k, v in entity_info.items():
  104. attr_name = 'f_%s' % k
  105. if hasattr(DataTableTracking, attr_name):
  106. setattr(tracker, attr_name, v)
  107. if entity_info.get("have_parent"):
  108. parent_trackers = DataTableTracking.select().where(
  109. DataTableTracking.f_table_name == entity_info.get("parent_table_name"),
  110. DataTableTracking.f_table_namespace == entity_info.get("parent_table_namespace")).order_by(DataTableTracking.f_create_time.desc())
  111. if not parent_trackers:
  112. tracker.f_source_table_name = entity_info.get("parent_table_name")
  113. tracker.f_source_table_namespace = entity_info.get("parent_table_namespace")
  114. else:
  115. parent_tracker = parent_trackers[0]
  116. if parent_tracker.f_have_parent:
  117. tracker.f_source_table_name = parent_tracker.f_source_table_name
  118. tracker.f_source_table_namespace = parent_tracker.f_source_table_namespace
  119. else:
  120. tracker.f_source_table_name = parent_tracker.f_table_name
  121. tracker.f_source_table_namespace = parent_tracker.f_table_namespace
  122. rows = tracker.save(force_insert=True)
  123. if rows != 1:
  124. raise Exception("Create {} failed".format(tracker))
  125. return tracker
  126. @classmethod
  127. @DB.connection_context()
  128. def query_tracker(cls, table_name, table_namespace, is_parent=False):
  129. if not is_parent:
  130. filters = [operator.attrgetter('f_table_name')(DataTableTracking) == table_name,
  131. operator.attrgetter('f_table_namespace')(DataTableTracking) == table_namespace]
  132. else:
  133. filters = [operator.attrgetter('f_parent_table_name')(DataTableTracking) == table_name,
  134. operator.attrgetter('f_parent_table_namespace')(DataTableTracking) == table_namespace]
  135. trackers = DataTableTracking.select().where(*filters)
  136. return [tracker for tracker in trackers]
  137. @classmethod
  138. @DB.connection_context()
  139. def get_parent_table(cls, table_name, table_namespace):
  140. trackers = DataTableTracker.query_tracker(table_name, table_namespace)
  141. if not trackers:
  142. raise Exception(f"no found table: table name {table_name}, table namespace {table_namespace}")
  143. else:
  144. parent_table_info = []
  145. for tracker in trackers:
  146. if not tracker.f_have_parent:
  147. return []
  148. else:
  149. parent_table_info.append({"parent_table_name": tracker.f_parent_table_name,
  150. "parent_table_namespace": tracker.f_parent_table_namespace,
  151. "source_table_name": tracker.f_source_table_name,
  152. "source_table_namespace": tracker.f_source_table_namespace
  153. })
  154. return parent_table_info
  155. @classmethod
  156. @DB.connection_context()
  157. def track_job(cls, table_name, table_namespace, display=False):
  158. trackers = DataTableTracker.query_tracker(table_name, table_namespace, is_parent=True)
  159. job_id_list = []
  160. for tracker in trackers:
  161. job_id_list.append(tracker.f_job_id)
  162. job_id_list = list(set(job_id_list))
  163. return {"count": len(job_id_list)} if not display else {"count": len(job_id_list), "job": job_id_list}
  164. class TableStorage:
  165. @staticmethod
  166. def copy_table(src_table: StorageTableABC, dest_table: StorageTableABC, deserialize_value=False):
  167. count = 0
  168. data_temp = []
  169. part_of_data = []
  170. src_table_meta = src_table.meta
  171. schema = {}
  172. update_schema = False
  173. line_index = 0
  174. fate_uuid = uuid.uuid1().hex
  175. if not src_table_meta.get_in_serialized():
  176. if src_table_meta.get_have_head():
  177. get_head = False
  178. else:
  179. get_head = True
  180. if not src_table.meta.get_extend_sid():
  181. get_line = data_utils.get_data_line
  182. elif not src_table_meta.get_auto_increasing_sid():
  183. get_line = data_utils.get_sid_data_line
  184. else:
  185. get_line = data_utils.get_auto_increasing_sid_data_line
  186. for line in src_table.read():
  187. if not get_head:
  188. schema = data_utils.get_header_schema(
  189. header_line=line,
  190. id_delimiter=src_table_meta.get_id_delimiter(),
  191. extend_sid=src_table_meta.get_extend_sid(),
  192. )
  193. get_head = True
  194. continue
  195. values = line.rstrip().split(src_table.meta.get_id_delimiter())
  196. k, v = get_line(
  197. values=values,
  198. line_index=line_index,
  199. extend_sid=src_table.meta.get_extend_sid(),
  200. auto_increasing_sid=src_table.meta.get_auto_increasing_sid(),
  201. id_delimiter=src_table.meta.get_id_delimiter(),
  202. fate_uuid=fate_uuid,
  203. )
  204. line_index += 1
  205. count = TableStorage.put_in_table(
  206. table=dest_table,
  207. k=k,
  208. v=v,
  209. temp=data_temp,
  210. count=count,
  211. part_of_data=part_of_data,
  212. )
  213. else:
  214. source_header = copy.deepcopy(src_table_meta.get_schema().get("header"))
  215. TableStorage.update_full_header(src_table_meta)
  216. for k, v in src_table.collect():
  217. if src_table.meta.get_extend_sid():
  218. # extend id
  219. v = src_table.meta.get_id_delimiter().join([k, v])
  220. k = line_extend_uuid(fate_uuid, line_index)
  221. line_index += 1
  222. if deserialize_value:
  223. # writer component: deserialize value
  224. v, extend_header = feature_utils.get_deserialize_value(v, dest_table.meta.get_id_delimiter())
  225. if not update_schema:
  226. header_list = get_component_output_data_schema(src_table.meta, extend_header)
  227. schema = get_header_schema(dest_table.meta.get_id_delimiter().join(header_list),
  228. dest_table.meta.get_id_delimiter())
  229. _, dest_table.meta = dest_table.meta.update_metas(schema=schema)
  230. update_schema = True
  231. count = TableStorage.put_in_table(
  232. table=dest_table,
  233. k=k,
  234. v=v,
  235. temp=data_temp,
  236. count=count,
  237. part_of_data=part_of_data,
  238. )
  239. schema = src_table.meta.get_schema()
  240. schema["header"] = source_header
  241. if data_temp:
  242. dest_table.put_all(data_temp)
  243. if schema.get("extend_tag"):
  244. schema.update({"extend_tag": False})
  245. _, dest_table.meta = dest_table.meta.update_metas(schema=schema if not update_schema else None, part_of_data=part_of_data)
  246. return dest_table.count()
  247. @staticmethod
  248. def update_full_header(table_meta):
  249. schema = table_meta.get_schema()
  250. if schema.get("anonymous_header"):
  251. header = AnonymousGenerator.reconstruct_header(schema)
  252. schema["header"] = header
  253. table_meta.set_metas(schema=schema)
  254. @staticmethod
  255. def put_in_table(table: StorageTableABC, k, v, temp, count, part_of_data, max_num=10000):
  256. temp.append((k, v))
  257. if count < 100:
  258. part_of_data.append((k, v))
  259. if len(temp) == max_num:
  260. table.put_all(temp)
  261. temp.clear()
  262. return count + 1
  263. @staticmethod
  264. def send_table(output_tables_meta, tar_file_name="", limit=-1, need_head=True, local_download=False, output_data_file_path=None):
  265. output_data_file_list = []
  266. output_data_meta_file_list = []
  267. output_tmp_dir = os.path.join(get_fate_flow_directory(), 'tmp/{}/{}'.format(datetime.datetime.now().strftime("%Y%m%d"), fate_uuid()))
  268. for output_name, output_table_meta in output_tables_meta.items():
  269. output_data_count = 0
  270. if not local_download:
  271. output_data_file_path = "{}/{}.csv".format(output_tmp_dir, output_name)
  272. output_data_meta_file_path = "{}/{}.meta".format(output_tmp_dir, output_name)
  273. os.makedirs(os.path.dirname(output_data_file_path), exist_ok=True)
  274. with open(output_data_file_path, 'w') as fw:
  275. with Session() as sess:
  276. output_table = sess.get_table(name=output_table_meta.get_name(),
  277. namespace=output_table_meta.get_namespace())
  278. all_extend_header = {}
  279. if output_table:
  280. for k, v in output_table.collect():
  281. data_line, is_str, all_extend_header = feature_utils.get_component_output_data_line(
  282. src_key=k,
  283. src_value=v,
  284. schema=output_table_meta.get_schema(),
  285. all_extend_header=all_extend_header)
  286. # save meta
  287. if output_data_count == 0:
  288. output_data_file_list.append(output_data_file_path)
  289. extend_header = feature_utils.generate_header(all_extend_header,
  290. schema=output_table_meta.get_schema())
  291. header = get_component_output_data_schema(output_table_meta=output_table_meta,
  292. is_str=is_str,
  293. extend_header=extend_header)
  294. if not local_download:
  295. output_data_meta_file_list.append(output_data_meta_file_path)
  296. with open(output_data_meta_file_path, 'w') as f:
  297. json.dump({'header': header}, f, indent=4)
  298. if need_head and header and output_table_meta.get_have_head() and \
  299. output_table_meta.get_schema().get("is_display", True):
  300. fw.write('{}\n'.format(','.join(header)))
  301. delimiter = output_table_meta.get_id_delimiter() if output_table_meta.get_id_delimiter() else ","
  302. fw.write('{}\n'.format(delimiter.join(map(lambda x: str(x), data_line))))
  303. output_data_count += 1
  304. if output_data_count == limit:
  305. break
  306. if local_download:
  307. return
  308. # tar
  309. output_data_tarfile = "{}/{}".format(output_tmp_dir, tar_file_name)
  310. tar = tarfile.open(output_data_tarfile, mode='w:gz')
  311. for index in range(0, len(output_data_file_list)):
  312. tar.add(output_data_file_list[index], os.path.relpath(output_data_file_list[index], output_tmp_dir))
  313. tar.add(output_data_meta_file_list[index],
  314. os.path.relpath(output_data_meta_file_list[index], output_tmp_dir))
  315. tar.close()
  316. for key, path in enumerate(output_data_file_list):
  317. try:
  318. os.remove(path)
  319. os.remove(output_data_meta_file_list[key])
  320. except Exception as e:
  321. # warning
  322. stat_logger.warning(e)
  323. return send_file(output_data_tarfile, attachment_filename=tar_file_name, as_attachment=True)
  324. def delete_tables_by_table_infos(output_data_table_infos):
  325. data = []
  326. status = False
  327. with Session() as sess:
  328. for output_data_table_info in output_data_table_infos:
  329. table_name = output_data_table_info.f_table_name
  330. namespace = output_data_table_info.f_table_namespace
  331. table_info = {'table_name': table_name, 'namespace': namespace}
  332. if table_name and namespace and table_info not in data:
  333. table = sess.get_table(table_name, namespace)
  334. if table:
  335. try:
  336. table.destroy()
  337. data.append(table_info)
  338. status = True
  339. except Exception as e:
  340. stat_logger.warning(e)
  341. return status, data
  342. def delete_metric_data(metric_info):
  343. status = delete_metric_data_from_db(metric_info)
  344. return f"delete status: {status}"
  345. @DB.connection_context()
  346. def delete_metric_data_from_db(metric_info):
  347. tracking_metric_model = type(TrackingMetric.model(table_index=metric_info.get("job_id")[:8]))
  348. operate = tracking_metric_model.delete().where(*get_delete_filters(tracking_metric_model, metric_info))
  349. return operate.execute() > 0
  350. def get_delete_filters(tracking_metric_model, metric_info):
  351. delete_filters = []
  352. primary_keys = ["job_id", "role", "party_id", "component_name"]
  353. for key in primary_keys:
  354. if key in metric_info:
  355. delete_filters.append(operator.attrgetter("f_%s" % key)(tracking_metric_model) == metric_info[key])
  356. return delete_filters
  357. def get_component_output_data_schema(output_table_meta, extend_header, is_str=False) -> list:
  358. # get schema
  359. schema = output_table_meta.get_schema()
  360. if not schema:
  361. return []
  362. header = [schema.get('sid_name') or schema.get('sid', 'sid')]
  363. if schema.get("extend_tag"):
  364. header = []
  365. if "label" in extend_header and schema.get("label_name"):
  366. extend_header[extend_header.index("label")] = schema.get("label_name")
  367. header.extend(extend_header)
  368. if is_str or isinstance(schema.get('header'), str):
  369. if schema.get("original_index_info"):
  370. header = [schema.get('sid_name') or schema.get('sid', 'sid')]
  371. header.extend(AnonymousGenerator.reconstruct_header(schema))
  372. return header
  373. if not schema.get('header'):
  374. if schema.get('sid'):
  375. return [schema.get('sid')]
  376. else:
  377. return []
  378. if isinstance(schema.get('header'), str):
  379. schema_header = schema.get('header').split(',')
  380. elif isinstance(schema.get('header'), list):
  381. schema_header = schema.get('header')
  382. else:
  383. raise ValueError("header type error")
  384. header.extend([feature for feature in schema_header])
  385. else:
  386. header.extend(schema.get('header', []))
  387. return header