#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import shutil
import sys
import time
import uuid

from fate_arch import storage, session
from fate_arch.common import EngineType, log, path_utils
from fate_arch.common.data_utils import default_input_fs_path
from fate_arch.session import Session
from fate_arch.storage import DEFAULT_ID_DELIMITER, EggRollStoreType, StorageEngine, StorageTableOrigin
from fate_flow.components._base import (
    BaseParam,
    ComponentBase,
    ComponentMeta,
    ComponentInputProtocol,
)
from fate_flow.components.param_extract import ParamExtract
from fate_flow.entity import Metric, MetricMeta, MetricType
from fate_flow.manager.data_manager import DataTableTracker, AnonymousGenerator, SchemaMetaParam
from fate_flow.scheduling_apps.client import ControllerClient
from fate_flow.db.job_default_config import JobDefaultConfig
from fate_flow.utils import data_utils, job_utils
from fate_flow.utils.base_utils import get_fate_flow_directory

LOGGER = log.getLogger()

upload_cpn_meta = ComponentMeta("Upload")


@upload_cpn_meta.bind_param
class UploadParam(BaseParam):
    def __init__(
        self,
        file="",
        head=1,
        id_delimiter=DEFAULT_ID_DELIMITER,
        partition=10,
        namespace="",
        name="",
        storage_engine="",
        storage_address=None,
        destroy=False,
        extend_sid=False,
        auto_increasing_sid=False,
        block_size=1,
        schema=None,
        # extra param
        with_meta=False,
        meta={}
    ):
        self.file = file
        self.head = head
        self.id_delimiter = id_delimiter
        self.partition = partition
        self.namespace = namespace
        self.name = name
        self.storage_engine = storage_engine
        self.storage_address = storage_address
        self.destroy = destroy
        self.extend_sid = extend_sid
        self.auto_increasing_sid = auto_increasing_sid
        self.block_size = block_size
        self.schema = schema if schema else {}
        # extra param
        self.with_meta = with_meta
        self.meta = meta

    def check(self):
        return True

    def update(self, conf, allow_redundant=False):
        LOGGER.info(f"update:{conf}")
        params = ParamExtract().recursive_parse_param_from_config(
            param=self,
            config_json=conf,
            param_parse_depth=0,
            valid_check=not allow_redundant,
            name=self._name,
        )
        params.update_meta(params)
        LOGGER.info(f"update result:{params.__dict__}")
        return params

    @staticmethod
    def update_meta(params):
        if params.with_meta:
            _meta = SchemaMetaParam(**params.meta).to_dict()
            if params.extend_sid:
                _meta["with_match_id"] = True
        else:
            _meta = {}
        params.meta = _meta
        return params
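

# An illustrative parameter set for this component (a sketch only; the values below
# are hypothetical and simply mirror the UploadParam fields defined above):
#
#     {
#         "file": "examples/data/my_data.csv",  # hypothetical path, relative paths resolve against the fate_flow directory
#         "head": 1,
#         "partition": 8,
#         "id_delimiter": ",",
#         "namespace": "experiment",
#         "name": "my_data",
#         "storage_engine": "",                 # empty -> fall back to the job's storage engine
#         "extend_sid": False,
#         "with_meta": False
#     }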


@upload_cpn_meta.bind_runner.on_local
class Upload(ComponentBase):
    def __init__(self):
        super(Upload, self).__init__()
        self.MAX_PARTITIONS = 1024
        self.MAX_BYTES = 1024 * 1024 * 8 * 500
        self.parameters = {}
        self.table = None
        self.is_block = False
        self.session_id = None
        self.session = None
        self.storage_engine = None

    def _run(self, cpn_input: ComponentInputProtocol):
        self.parameters = cpn_input.parameters
        LOGGER.info(self.parameters)
        self.parameters["role"] = cpn_input.roles["role"]
        self.parameters["local"] = cpn_input.roles["local"]
        storage_engine = self.parameters["storage_engine"].upper()
        storage_address = self.parameters["storage_address"]
        # if no storage engine is specified, fall back to the job's storage engine
        if not storage_engine:
            storage_engine = cpn_input.job_parameters.storage_engine
        self.storage_engine = storage_engine
        if not storage_address:
            storage_address = cpn_input.job_parameters.engines_address[
                EngineType.STORAGE
            ]
        job_id = self.task_version_id.split("_")[0]
        if not os.path.isabs(self.parameters.get("file", "")):
            self.parameters["file"] = os.path.join(
                get_fate_flow_directory(), self.parameters["file"]
            )
        if not os.path.exists(self.parameters["file"]):
            raise Exception(
                "%s does not exist, please check the configuration"
                % (self.parameters["file"])
            )
        if not os.path.getsize(self.parameters["file"]):
            raise Exception("%s is an empty file" % (self.parameters["file"]))
        name, namespace = self.parameters.get("name"), self.parameters.get("namespace")
        _namespace, _table_name = self.generate_table_name(self.parameters["file"])
        if namespace is None:
            namespace = _namespace
        if name is None:
            name = _table_name
        if self.parameters.get("with_meta"):
            self.parameters["id_delimiter"] = self.parameters.get("meta").get("delimiter")
        read_head = self.parameters["head"]
        if read_head == 0:
            head = False
        elif read_head == 1:
            head = True
        else:
            raise Exception("'head' in conf.json should be 0 or 1")
        partitions = self.parameters["partition"]
        if partitions <= 0 or partitions >= self.MAX_PARTITIONS:
            raise Exception(
                "invalid number of partitions, it should be between %d and %d"
                % (0, self.MAX_PARTITIONS)
            )
        self.session_id = job_utils.generate_session_id(
            self.tracker.task_id,
            self.tracker.task_version,
            self.tracker.role,
            self.tracker.party_id,
        )
        sess = Session.get_global()
        self.session = sess
        if self.parameters.get("destroy", False):
            table = sess.get_table(namespace=namespace, name=name)
            if table:
                LOGGER.info(
                    f"destroy table name: {name} namespace: {namespace} engine: {table.engine}"
                )
                try:
                    table.destroy()
                except Exception as e:
                    LOGGER.error(e)
            else:
                LOGGER.info(
                    f"cannot find table name: {name} namespace: {namespace}, skip destroy"
                )
        address_dict = storage_address.copy()
        storage_session = sess.storage(
            storage_engine=storage_engine, options=self.parameters.get("options")
        )
        upload_address = {}
        if storage_engine in {StorageEngine.EGGROLL, StorageEngine.STANDALONE}:
            upload_address = {
                "name": name,
                "namespace": namespace,
                "storage_type": EggRollStoreType.ROLLPAIR_LMDB,
            }
        elif storage_engine in {StorageEngine.MYSQL, StorageEngine.HIVE}:
            if not address_dict.get("db") or not address_dict.get("name"):
                upload_address = {"db": namespace, "name": name}
        elif storage_engine in {StorageEngine.PATH}:
            upload_address = {"path": self.parameters["file"]}
        elif storage_engine in {StorageEngine.HDFS}:
            upload_address = {
                "path": default_input_fs_path(
                    name=name,
                    namespace=namespace,
                    prefix=address_dict.get("path_prefix"),
                )
            }
        elif storage_engine in {StorageEngine.LOCALFS}:
            upload_address = {
                "path": default_input_fs_path(
                    name=name,
                    namespace=namespace,
                    storage_engine=storage_engine
                )
            }
        else:
            raise RuntimeError(f"unsupported storage engine: {storage_engine}")
        address_dict.update(upload_address)
        LOGGER.info(f"upload to {storage_engine} storage, address: {address_dict}")
        address = storage.StorageTableMeta.create_address(
            storage_engine=storage_engine, address_dict=address_dict
        )
        self.parameters["partitions"] = partitions
        self.parameters["name"] = name
        self.table = storage_session.create_table(address=address, origin=StorageTableOrigin.UPLOAD, **self.parameters)
        if storage_engine not in [StorageEngine.PATH]:
            data_table_count = self.save_data_table(job_id, name, namespace, head)
        else:
            data_table_count = self.get_data_table_count(
                self.parameters["file"], name, namespace
            )
        self.table.meta.update_metas(in_serialized=True)
        DataTableTracker.create_table_tracker(
            table_name=name,
            table_namespace=namespace,
            entity_info={"job_id": job_id, "have_parent": False},
        )
        LOGGER.info("------------load data finish!-----------------")
        # remove the temporary upload file, if one was created for this job
        try:
            if "{}/fate_upload_tmp".format(job_id) in self.parameters["file"]:
                LOGGER.info("remove tmp upload file")
                LOGGER.info(os.path.dirname(self.parameters["file"]))
                shutil.rmtree(os.path.dirname(self.parameters["file"]))
        except:
            LOGGER.info("remove tmp file failed")
        LOGGER.info("file: {}".format(self.parameters["file"]))
        LOGGER.info("total data_count: {}".format(data_table_count))
        LOGGER.info("table name: {}, table namespace: {}".format(name, namespace))

    def save_data_table(self, job_id, dst_table_name, dst_table_namespace, head=True):
        input_file = self.parameters["file"]
        input_feature_count = self.get_count(input_file)
        self.upload_file(input_file, head, job_id, input_feature_count)
        table_count = self.table.count()
        metas_info = {
            "count": table_count,
            "partitions": self.parameters["partition"],
            "extend_sid": self.parameters["extend_sid"]
        }
        if self.parameters.get("with_meta"):
            metas_info.update({"schema": self.generate_anonymous_schema()})
        self.table.meta.update_metas(**metas_info)
        self.save_meta(
            dst_table_namespace=dst_table_namespace,
            dst_table_name=dst_table_name,
            table_count=table_count,
        )
        return table_count

    @staticmethod
    def get_count(input_file):
        with open(input_file, "r", encoding="utf-8") as fp:
            count = 0
            for _ in fp:
                count += 1
        return count
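
    # upload_file below streams the input in blocks rather than loading it all at once:
    # each readlines() call reads up to JobDefaultConfig.upload_block_max_bytes, every
    # line is turned into a (key, value) pair by the helper chosen by get_line(), the
    # block is written with table.put_all(), and progress is reported to the controller.
    # The first block is also stored as part_of_data so a data preview is available.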
    def upload_file(self, input_file, head, job_id=None, input_feature_count=None, table=None):
        if not table:
            table = self.table
        with open(input_file, "r") as fin:
            lines_count = 0
            if head is True:
                data_head = fin.readline()
                input_feature_count -= 1
                self.update_table_schema(data_head)
            else:
                self.update_table_schema()
            n = 0
            fate_uuid = uuid.uuid1().hex
            get_line = self.get_line()
            line_index = 0
            while True:
                data = list()
                lines = fin.readlines(JobDefaultConfig.upload_block_max_bytes)
                LOGGER.info(JobDefaultConfig.upload_block_max_bytes)
                if lines:
                    # self.append_data_line(lines, data, n)
                    for line in lines:
                        values = line.rstrip().split(self.parameters["id_delimiter"])
                        k, v = get_line(
                            values=values,
                            line_index=line_index,
                            extend_sid=self.parameters["extend_sid"],
                            auto_increasing_sid=self.parameters["auto_increasing_sid"],
                            id_delimiter=self.parameters["id_delimiter"],
                            fate_uuid=fate_uuid,
                        )
                        data.append((k, v))
                        line_index += 1
                    lines_count += len(data)
                    save_progress = lines_count / input_feature_count * 100 // 1
                    job_info = {
                        "progress": save_progress,
                        "job_id": job_id,
                        "role": self.parameters["local"]["role"],
                        "party_id": self.parameters["local"]["party_id"],
                    }
                    ControllerClient.update_job(job_info=job_info)
                    table.put_all(data)
                    if n == 0:
                        table.meta.update_metas(part_of_data=data)
                else:
                    return
                n += 1

    def get_computing_table(self, name, namespace, schema=None):
        storage_table_meta = storage.StorageTableMeta(name=name, namespace=namespace)
        computing_table = session.get_computing_session().load(
            storage_table_meta.get_address(),
            schema=schema if schema else storage_table_meta.get_schema(),
            partitions=self.parameters.get("partitions"))
        return computing_table

    def generate_anonymous_schema(self):
        computing_table = self.get_computing_table(self.table.name, self.table.namespace)
        LOGGER.info(f"computing table schema: {computing_table.schema}")
        schema = computing_table.schema
        if schema.get("meta"):
            schema.update(AnonymousGenerator.generate_header(computing_table, schema))
            schema = AnonymousGenerator.generate_anonymous_header(schema=schema)
            LOGGER.info(f"extra schema: {schema}")
        return schema

    def update_table_schema(self, data_head=""):
        LOGGER.info(f"data head: {data_head}")
        schema = data_utils.get_header_schema(
            header_line=data_head,
            id_delimiter=self.parameters["id_delimiter"],
            extend_sid=self.parameters["extend_sid"],
        )
        # update extra schema and meta info
        schema.update(self.parameters.get("schema", {}))
        schema.update({"meta": self.parameters.get("meta", {})})
        _, meta = self.table.meta.update_metas(
            schema=schema,
            auto_increasing_sid=self.parameters["auto_increasing_sid"],
            extend_sid=self.parameters["extend_sid"],
        )
        self.table.meta = meta
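
    # get_line() picks the line-parsing helper from data_utils according to the sid settings:
    #   extend_sid is False                              -> get_data_line
    #   extend_sid is True, auto_increasing_sid is False -> get_sid_data_line
    #   extend_sid is True, auto_increasing_sid is True  -> get_auto_increasing_sid_data_line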
    def get_line(self):
        if not self.parameters["extend_sid"]:
            line = data_utils.get_data_line
        elif not self.parameters["auto_increasing_sid"]:
            line = data_utils.get_sid_data_line
        else:
            line = data_utils.get_auto_increasing_sid_data_line
        return line

    @staticmethod
    def generate_table_name(input_file_path):
        # default naming: the caller uses the file's base name as the namespace
        # and the upload timestamp as the table name
        str_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
        file_name = input_file_path.split(".")[0]
        file_name = file_name.split("/")[-1]
        return file_name, str_time

    def save_meta(self, dst_table_namespace, dst_table_name, table_count):
        self.tracker.log_output_data_info(
            data_name="upload",
            table_namespace=dst_table_namespace,
            table_name=dst_table_name,
        )
        self.tracker.log_metric_data(
            metric_namespace="upload",
            metric_name="data_access",
            metrics=[Metric("count", table_count)],
        )
        self.tracker.set_metric_meta(
            metric_namespace="upload",
            metric_name="data_access",
            metric_meta=MetricMeta(name="upload", metric_type=MetricType.UPLOAD),
        )

    def get_data_table_count(self, path, name, namespace):
        count = path_utils.get_data_table_count(path)
        self.save_meta(
            dst_table_namespace=namespace, dst_table_name=name, table_count=count
        )
        self.table.meta.update_metas(count=count)
        return count