_table.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from contextlib import closing
  17. import requests
  18. import os
  19. from fate_arch.common.log import getLogger
  20. from fate_arch.storage import StorageEngine, ApiStoreType
  21. from fate_arch.storage import StorageTableBase
# Module-level logger created through fate_arch's ``getLogger`` helper.
# NOTE(review): not referenced anywhere in the visible code - confirm it is
# still needed before removing.
LOGGER = getLogger()
  23. class StorageTable(StorageTableBase):
  24. def __init__(
  25. self,
  26. path,
  27. address=None,
  28. name: str = None,
  29. namespace: str = None,
  30. partitions: int = None,
  31. store_type: ApiStoreType = ApiStoreType.EXTERNAL,
  32. options=None,
  33. ):
  34. self.path = path
  35. self.data_count = 0
  36. super(StorageTable, self).__init__(
  37. name=name,
  38. namespace=namespace,
  39. address=address,
  40. partitions=partitions,
  41. options=options,
  42. engine=StorageEngine.API,
  43. store_type=store_type,
  44. )
  45. def _collect(self, **kwargs) -> list:
  46. self.request = getattr(requests, self.address.method.lower(), None)
  47. id_delimiter = self._meta.get_id_delimiter()
  48. with closing(self.request(url=self.address.url, json=self.address.body, headers=self.address.header,
  49. stream=True)) as response:
  50. if response.status_code == 200:
  51. os.makedirs(os.path.dirname(self.path), exist_ok=True)
  52. with open(self.path, 'wb') as fw:
  53. for chunk in response.iter_content(1024):
  54. if chunk:
  55. fw.write(chunk)
  56. with open(self.path, "r") as f:
  57. while True:
  58. lines = f.readlines(1024 * 1024 * 1024)
  59. if lines:
  60. for line in lines:
  61. self.data_count += 1
  62. id = line.split(id_delimiter)[0]
  63. feature = id_delimiter.join(line.split(id_delimiter)[1:])
  64. yield id, feature
  65. else:
  66. _, self._meta = self._meta.update_metas(count=self.data_count)
  67. break
  68. else:
  69. raise Exception(response.status_code, response.text)
  70. def _read(self) -> list:
  71. return []
  72. def _destroy(self):
  73. pass
  74. def _save_as(self, **kwargs):
  75. pass
  76. def _count(self):
  77. return self.data_count