download.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. This code is adapted from torchvision with some modifications.
  3. """
  4. import gzip
  5. import hashlib
  6. import logging
  7. import os
  8. import tarfile
  9. import zipfile
  10. import requests
  11. from tqdm import tqdm
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  13. def gen_bar_updater():
  14. pbar = tqdm(total=None)
  15. def bar_update(count, block_size, total_size):
  16. if pbar.total is None and total_size:
  17. pbar.total = total_size
  18. progress_bytes = count * block_size
  19. pbar.update(progress_bytes - pbar.n)
  20. return bar_update
  21. def calculate_md5(fpath, chunk_size=1024 * 1024):
  22. md5 = hashlib.md5()
  23. with open(fpath, 'rb') as f:
  24. for chunk in iter(lambda: f.read(chunk_size), b''):
  25. md5.update(chunk)
  26. return md5.hexdigest()
  27. def check_md5(fpath, md5, **kwargs):
  28. return md5 == calculate_md5(fpath, **kwargs)
  29. def check_integrity(fpath, md5=None):
  30. if not os.path.isfile(fpath):
  31. return False
  32. if md5 is None:
  33. return True
  34. return check_md5(fpath, md5)
  35. def download_url(url, root, filename=None, md5=None):
  36. """Download a file from a url and place it in root.
  37. Args:
  38. url (str): URL to download file from
  39. root (str): Directory to place downloaded file in
  40. filename (str, optional): Name to save the file under. If None, use the basename of the URL
  41. """
  42. import urllib.request
  43. import urllib.error
  44. root = os.path.expanduser(root)
  45. if not filename:
  46. filename = os.path.basename(url)
  47. fpath = os.path.join(root, filename)
  48. os.makedirs(root, exist_ok=True)
  49. # check if file is already present locally
  50. if check_integrity(fpath, md5):
  51. logger.info("Using downloaded and verified file: " + fpath)
  52. return fpath
  53. else: # download the file
  54. try:
  55. logger.info("Downloading {} to {}".format(url, fpath))
  56. urllib.request.urlretrieve(
  57. url, fpath,
  58. reporthook=gen_bar_updater()
  59. )
  60. except (urllib.error.URLError, IOError) as e:
  61. if url[:5] != 'https':
  62. raise e
  63. url = url.replace('https:', 'http:')
  64. logger.info("Failed download. Trying https -> http instead."
  65. "Downloading {} to {}".format(url, fpath))
  66. urllib.request.urlretrieve(
  67. url, fpath,
  68. reporthook=gen_bar_updater()
  69. )
  70. # check integrity of downloaded file
  71. if not check_integrity(fpath, md5):
  72. raise RuntimeError("File not found or corrupted.")
  73. return fpath
  74. def download_from_google_drive(id, destination):
  75. # taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
  76. URL = "https://docs.google.com/uc?export=download"
  77. session = requests.Session()
  78. response = session.get(URL, params={'id': id}, stream=True)
  79. token = get_confirm_token(response)
  80. if token:
  81. params = {'id': id, 'confirm': token}
  82. response = session.get(URL, params=params, stream=True)
  83. else:
  84. raise FileNotFoundError("Google drive file id does not exist")
  85. save_response_content(response, destination)
  86. def get_confirm_token(response):
  87. for key, value in response.cookies.items():
  88. if key.startswith('download_warning'):
  89. return value
  90. return None
  91. def save_response_content(response, destination):
  92. CHUNK_SIZE = 32768
  93. with open(destination, "wb") as f:
  94. for chunk in response.iter_content(CHUNK_SIZE):
  95. if chunk: # filter out keep-alive new chunks
  96. f.write(chunk)
  97. def _is_tarxz(filename):
  98. return filename.endswith(".tar.xz")
  99. def _is_tar(filename):
  100. return filename.endswith(".tar")
  101. def _is_targz(filename):
  102. return filename.endswith(".tar.gz")
  103. def _is_tgz(filename):
  104. return filename.endswith(".tgz")
  105. def _is_gzip(filename):
  106. return filename.endswith(".gz") and not filename.endswith(".tar.gz")
  107. def _is_zip(filename):
  108. return filename.endswith(".zip")
  109. def extract_archive(from_path, to_path=None, remove_finished=False):
  110. if to_path is None:
  111. to_path = os.path.dirname(from_path)
  112. if _is_tar(from_path):
  113. with tarfile.open(from_path, 'r') as tar:
  114. tar.extractall(path=to_path)
  115. elif _is_targz(from_path) or _is_tgz(from_path):
  116. with tarfile.open(from_path, 'r:gz') as tar:
  117. tar.extractall(path=to_path)
  118. elif _is_tarxz(from_path):
  119. with tarfile.open(from_path, 'r:xz') as tar:
  120. tar.extractall(path=to_path)
  121. elif _is_gzip(from_path):
  122. to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
  123. with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
  124. out_f.write(zip_f.read())
  125. elif _is_zip(from_path):
  126. with zipfile.ZipFile(from_path, 'r') as z:
  127. z.extractall(to_path)
  128. else:
  129. raise ValueError("file format not supported")
  130. if remove_finished:
  131. os.remove(from_path)