123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- """
- This code is adapted from torchvision with some modifications.
- """
- import gzip
- import hashlib
- import logging
- import os
- import tarfile
- import zipfile
- import requests
- from tqdm import tqdm
- logger = logging.getLogger(__name__)
def gen_bar_updater():
    """Build a ``reporthook`` callback that drives a tqdm progress bar.

    The returned callable has the ``(count, block_size, total_size)``
    signature that ``urllib.request.urlretrieve`` passes to its
    ``reporthook`` argument.

    Returns:
        callable: ``bar_update(count, block_size, total_size)`` which advances
        a freshly created tqdm bar to ``count * block_size`` bytes.
    """
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # urlretrieve reports total_size == -1 when the server sends no
        # length; only adopt a strictly positive size as the bar total.
        if pbar.total is None and total_size is not None and total_size > 0:
            pbar.total = total_size
        progress_bytes = count * block_size
        # The final block is usually shorter than block_size, so the product
        # can overshoot the real size; clamp to keep the bar within its total.
        if pbar.total:
            progress_bytes = min(progress_bytes, pbar.total)
        pbar.update(progress_bytes - pbar.n)

    return bar_update
def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it in chunks.

    Args:
        fpath (str): Path of the file to hash.
        chunk_size (int): Number of bytes requested per read.

    Returns:
        str: Hexadecimal MD5 digest of the file contents.
    """
    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                # An empty read marks the end of the file.
                break
            digest.update(block)
    return digest.hexdigest()
def check_md5(fpath, md5, **kwargs):
    """Tell whether a file's MD5 digest equals the expected one.

    Args:
        fpath (str): Path of the file to hash.
        md5 (str): Expected hexadecimal digest.
        **kwargs: Forwarded unchanged to ``calculate_md5``.

    Returns:
        bool: True when the computed digest equals ``md5``.
    """
    actual_digest = calculate_md5(fpath, **kwargs)
    return md5 == actual_digest
def check_integrity(fpath, md5=None):
    """Check that a file exists and, optionally, matches an MD5 digest.

    Args:
        fpath (str): Path of the file to check.
        md5 (str, optional): Expected hexadecimal digest. If None, only the
            existence of the file is checked.

    Returns:
        bool: False when ``fpath`` is not a regular file; otherwise True when
        ``md5`` is None, or the result of ``check_md5`` when it is given.
    """
    if os.path.isfile(fpath):
        return True if md5 is None else check_md5(fpath, md5)
    return False
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    If a file already exists at the target path and passes
    ``check_integrity``, nothing is downloaded.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): Expected MD5 hex digest of the file. If None, only existence is checked

    Returns:
        str: Path of the downloaded (or already present) file.

    Raises:
        urllib.error.URLError, IOError: If the download fails; for ``https``
            URLs this is raised only after the plain-``http`` retry also fails.
        RuntimeError: If the file is missing or fails the MD5 check after
            downloading.
    """
    import urllib.error
    import urllib.request

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        logger.info("Using downloaded and verified file: %s", fpath)
        return fpath

    # download the file
    try:
        logger.info("Downloading %s to %s", url, fpath)
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
    except (urllib.error.URLError, IOError):
        if not url.startswith('https:'):
            # Nothing to fall back to; re-raise with the original traceback.
            raise
        # NOTE(security): retrying over plain http drops transport security.
        # The md5 check below is the only safeguard, and only when md5 is given.
        # Rewrite only the scheme, not later occurrences of 'https:' in the URL.
        url = url.replace('https:', 'http:', 1)
        logger.info(
            "Failed download. Trying https -> http instead. "
            "Downloading %s to %s", url, fpath
        )
        urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
    return fpath
def download_from_google_drive(id, destination):
    """Download a Google Drive file to a local path.

    A first request is made to obtain the confirmation token that Google
    Drive sets in a ``download_warning*`` cookie; the file is then fetched
    with that token and written to ``destination``.

    Args:
        id (str): Google Drive file id.
        destination (str): Local path the downloaded content is written to.

    Raises:
        FileNotFoundError: If the first response carries no confirmation
            token.
    """
    # taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
    URL = "https://docs.google.com/uc?export=download"

    # Use the session as a context manager so its pooled connections are
    # released even when a request or the file write raises.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)
        if not token:
            # NOTE(review): a response without a confirmation cookie is
            # reported as a missing file id; verify this also holds for files
            # that Drive may serve without a confirmation step.
            raise FileNotFoundError("Google drive file id does not exist")

        # The first streamed body is never read; close it before re-requesting.
        response.close()
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)
        save_response_content(response, destination)
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object exposing a ``cookies`` mapping with ``items()``.

    Returns:
        The value of the first cookie whose name starts with
        ``download_warning``, or None when there is no such cookie.
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(warning_values, None)
def save_response_content(response, destination):
    """Stream a response body to a file in fixed-size chunks.

    Args:
        response: Object exposing ``iter_content(chunk_size)`` that yields
            byte chunks.
        destination (str): Path of the file to (over)write in binary mode.
    """
    CHUNK_SIZE = 32768

    with open(destination, "wb") as out_file:
        for piece in response.iter_content(CHUNK_SIZE):
            if not piece:
                # skip empty keep-alive chunks
                continue
            out_file.write(piece)
- def _is_tarxz(filename):
- return filename.endswith(".tar.xz")
- def _is_tar(filename):
- return filename.endswith(".tar")
- def _is_targz(filename):
- return filename.endswith(".tar.gz")
- def _is_tgz(filename):
- return filename.endswith(".tgz")
- def _is_gzip(filename):
- return filename.endswith(".gz") and not filename.endswith(".tar.gz")
- def _is_zip(filename):
- return filename.endswith(".zip")
def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract an archive, choosing the format from the file name suffix.

    Supported suffixes are ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.xz``,
    ``.gz`` and ``.zip``.

    Args:
        from_path (str): Path of the archive to extract.
        to_path (str, optional): Directory to extract into. If None, the
            directory containing ``from_path`` is used. For a plain ``.gz``
            file the output is written inside this directory under the
            archive's base name with the ``.gz`` suffix removed.
        remove_finished (bool): If True, delete ``from_path`` after a
            successful extraction.

    Raises:
        ValueError: If the suffix of ``from_path`` is not supported.
    """
    import shutil

    if to_path is None:
        to_path = os.path.dirname(from_path)

    # Map the tar flavours onto the matching tarfile open mode.
    if _is_tar(from_path):
        tar_mode = 'r'
    elif _is_targz(from_path) or _is_tgz(from_path):
        tar_mode = 'r:gz'
    elif _is_tarxz(from_path):
        tar_mode = 'r:xz'
    else:
        tar_mode = None

    # NOTE(security): extractall() trusts the member paths stored in the
    # archive; only use this on archives from a trusted source.
    if tar_mode is not None:
        with tarfile.open(from_path, tar_mode) as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            # Stream in chunks instead of loading the whole decompressed
            # payload into memory at once.
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("file format not supported")

    if remove_finished:
        os.remove(from_path)
|