_table.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import io
  17. import os
  18. from typing import Iterable
  19. from pyarrow import fs
  20. from fate_arch.common import hdfs_utils
  21. from fate_arch.common.log import getLogger
  22. from fate_arch.storage import StorageEngine, LocalFSStoreType
  23. from fate_arch.storage import StorageTableBase
  24. LOGGER = getLogger()
  25. class StorageTable(StorageTableBase):
  26. def __init__(
  27. self,
  28. address=None,
  29. name: str = None,
  30. namespace: str = None,
  31. partitions: int = 1,
  32. storage_type: LocalFSStoreType = LocalFSStoreType.DISK,
  33. options=None,
  34. ):
  35. super(StorageTable, self).__init__(
  36. name=name,
  37. namespace=namespace,
  38. address=address,
  39. partitions=partitions,
  40. options=options,
  41. engine=StorageEngine.LOCALFS,
  42. store_type=storage_type,
  43. )
  44. self._local_fs_client = fs.LocalFileSystem()
  45. @property
  46. def path(self):
  47. return self._address.path
  48. def _put_all(
  49. self, kv_list: Iterable, append=True, assume_file_exist=False, **kwargs
  50. ):
  51. LOGGER.info(f"put in file: {self.path}")
  52. # always create the directory first, otherwise the following creation of file will fail.
  53. self._local_fs_client.create_dir("/".join(self.path.split("/")[:-1]))
  54. if append and (assume_file_exist or self._exist()):
  55. stream = self._local_fs_client.open_append_stream(
  56. path=self.path, compression=None
  57. )
  58. else:
  59. stream = self._local_fs_client.open_output_stream(
  60. path=self.path, compression=None
  61. )
  62. counter = self._meta.get_count() if self._meta.get_count() else 0
  63. with io.TextIOWrapper(stream) as writer:
  64. for k, v in kv_list:
  65. writer.write(hdfs_utils.serialize(k, v))
  66. writer.write(hdfs_utils.NEWLINE)
  67. counter = counter + 1
  68. self._meta.update_metas(count=counter)
  69. def _collect(self, **kwargs) -> list:
  70. for line in self._as_generator():
  71. yield hdfs_utils.deserialize(line.rstrip())
  72. def _read(self) -> list:
  73. for line in self._as_generator():
  74. yield line
  75. def _destroy(self):
  76. # use try/catch to avoid stop while deleting an non-exist file
  77. try:
  78. self._local_fs_client.delete_file(self.path)
  79. except Exception as e:
  80. LOGGER.debug(e)
  81. def _count(self):
  82. count = 0
  83. for _ in self._as_generator():
  84. count += 1
  85. return count
  86. def _save_as(
  87. self, address, partitions=None, name=None, namespace=None, **kwargs
  88. ):
  89. self._local_fs_client.copy_file(src=self.path, dst=address.path)
  90. return StorageTable(
  91. address=address,
  92. partitions=partitions,
  93. name=name,
  94. namespace=namespace,
  95. **kwargs,
  96. )
  97. def close(self):
  98. pass
  99. def _exist(self):
  100. info = self._local_fs_client.get_file_info([self.path])[0]
  101. return info.type != fs.FileType.NotFound
  102. def _as_generator(self):
  103. info = self._local_fs_client.get_file_info([self.path])[0]
  104. if info.type == fs.FileType.NotFound:
  105. raise FileNotFoundError(f"file {self.path} not found")
  106. elif info.type == fs.FileType.File:
  107. for line in self._read_buffer_lines():
  108. yield line
  109. else:
  110. selector = fs.FileSelector(self.path)
  111. file_infos = self._local_fs_client.get_file_info(selector)
  112. for file_info in file_infos:
  113. if file_info.base_name.startswith(".") or file_info.base_name.startswith("_"):
  114. continue
  115. assert (
  116. file_info.is_file
  117. ), f"{self.path} is directory contains a subdirectory: {file_info.path}"
  118. with io.TextIOWrapper(
  119. buffer=self._local_fs_client.open_input_stream(
  120. f"{self._address.file_path:}/{file_info.path}"
  121. ),
  122. encoding="utf-8",
  123. ) as reader:
  124. for line in reader:
  125. yield line
  126. def _read_buffer_lines(self, path=None):
  127. if not path:
  128. path = self.path
  129. buffer = self._local_fs_client.open_input_file(self.path)
  130. offset = 0
  131. block_size = 1024 * 1024 * 10
  132. size = buffer.size()
  133. while offset < size:
  134. block_index = 1
  135. buffer_block = buffer.read_at(block_size, offset)
  136. if offset + block_size >= size:
  137. for line in self._read_lines(buffer_block):
  138. yield line
  139. break
  140. if buffer_block.endswith(b"\n"):
  141. for line in self._read_lines(buffer_block):
  142. yield line
  143. offset += block_size
  144. continue
  145. end_index = -1
  146. buffer_len = len(buffer_block)
  147. while not buffer_block[:end_index].endswith(b"\n"):
  148. if offset + block_index * block_size >= size:
  149. break
  150. end_index -= 1
  151. if abs(end_index) == buffer_len:
  152. block_index += 1
  153. buffer_block = buffer.read_at(block_index * block_size, offset)
  154. end_index = block_index * block_size
  155. for line in self._read_lines(buffer_block[:end_index]):
  156. yield line
  157. offset += len(buffer_block[:end_index])
  158. def _read_lines(self, buffer_block):
  159. with io.TextIOWrapper(buffer=io.BytesIO(buffer_block), encoding="utf-8") as reader:
  160. for line in reader:
  161. yield line