"""Download, checksum and archive helpers adapted from torchvision."""
import gzip
import hashlib
import logging
import os
import shutil
import tarfile
import zipfile

import requests
from tqdm import tqdm

logger = logging.getLogger(__name__)


def gen_bar_updater():
    """Create a ``reporthook`` callback that drives a tqdm progress bar.

    Returns:
        A function ``bar_update(count, block_size, total_size)`` that sets
        the bar's total on first use (when ``total_size`` is truthy) and
        advances the bar to ``count * block_size``.
    """
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        # The total is only known once the first callback arrives.
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        # tqdm.update takes an increment, so pass the delta since last call.
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def calculate_md5(fpath, chunk_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at ``fpath``.

    The file is read in pieces of ``chunk_size`` bytes so that it never has
    to be held in memory as a whole.
    """
    md5 = hashlib.md5()
    with open(fpath, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


def check_md5(fpath, md5, **kwargs):
    """Return True if the MD5 digest of ``fpath`` equals ``md5``.

    Extra keyword arguments are forwarded to :func:`calculate_md5`.
    """
    return md5 == calculate_md5(fpath, **kwargs)


def check_integrity(fpath, md5=None):
    """Return True if ``fpath`` is an existing file matching ``md5``.

    When ``md5`` is None only the existence of the file is checked.
    """
    if not os.path.isfile(fpath):
        return False
    if md5 is None:
        return True
    return check_md5(fpath, md5)


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    If a file with the expected checksum is already present it is reused.
    When an ``https`` download fails, the download is retried once over
    plain ``http``.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use
            the basename of the URL
        md5 (str, optional): Expected MD5 hex digest of the file. If None,
            only the existence of the file is verified.

    Returns:
        str: Path of the downloaded (or already present) file.

    Raises:
        urllib.error.URLError, IOError: If the download fails and no
            ``http`` fallback is possible (or the fallback fails as well).
        RuntimeError: If the file is missing or its checksum does not match
            after downloading.
    """
    import urllib.request
    import urllib.error

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    os.makedirs(root, exist_ok=True)

    # Reuse a file that is already present locally and passes verification.
    if check_integrity(fpath, md5):
        logger.info("Using downloaded and verified file: %s", fpath)
        return fpath

    try:
        logger.info("Downloading %s to %s", url, fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )
    except (urllib.error.URLError, IOError):
        if not url.startswith('https:'):
            raise
        # Only rewrite the scheme; leave any later 'https:' (for example
        # inside a query string) untouched.
        # NOTE: this silently downgrades to an unencrypted transport; pass
        # ``md5`` to make sure the payload is the expected one.
        url = url.replace('https:', 'http:', 1)
        logger.info("Failed download. Trying https -> http instead. "
                    "Downloading %s to %s", url, fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater()
        )

    # Verify what was just written to disk.
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
    return fpath


def download_from_google_drive(id, destination):
    """Download the Google Drive file ``id`` and save it to ``destination``.

    Taken from this StackOverflow answer:
    https://stackoverflow.com/a/39225039

    Args:
        id (str): Google Drive file id.
        destination (str): Path the downloaded content is written to.

    Raises:
        FileNotFoundError: If the first response carries no
            ``download_warning`` confirmation cookie.
    """
    URL = "https://docs.google.com/uc?export=download"

    # The context manager makes sure the session is closed on every path,
    # including the error path below.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if not token:
            raise FileNotFoundError("Google drive file id does not exist")

        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)
        save_response_content(response, destination)


def get_confirm_token(response):
    """Return the value of the ``download_warning*`` cookie, or None."""
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response, destination):
    """Stream the body of ``response`` into the file ``destination``."""
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


def _is_tarxz(filename):
    """Return True for names ending in ``.tar.xz``."""
    return filename.endswith(".tar.xz")


def _is_tar(filename):
    """Return True for names ending in ``.tar``."""
    return filename.endswith(".tar")


def _is_targz(filename):
    """Return True for names ending in ``.tar.gz``."""
    return filename.endswith(".tar.gz")


def _is_tgz(filename):
    """Return True for names ending in ``.tgz``."""
    return filename.endswith(".tgz")


def _is_gzip(filename):
    """Return True for ``.gz`` names that are not ``.tar.gz`` archives."""
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename):
    """Return True for names ending in ``.zip``."""
    return filename.endswith(".zip")


def extract_archive(from_path, to_path=None, remove_finished=False):
    """Extract the archive ``from_path`` into the directory ``to_path``.

    Supported formats, selected by file extension: ``.tar``, ``.tar.gz``,
    ``.tgz``, ``.tar.xz``, ``.gz`` (single file) and ``.zip``.

    NOTE: ``extractall`` trusts the member paths stored in the archive, so
    only use this on archives from a trusted source.

    Args:
        from_path (str): Path of the archive.
        to_path (str, optional): Directory to extract into. If None, the
            directory containing ``from_path`` is used.
        remove_finished (bool): If True, delete the archive after a
            successful extraction.

    Raises:
        ValueError: If the file extension is not one of the supported ones.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)

    if _is_tar(from_path):
        with tarfile.open(from_path, 'r') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path) or _is_tgz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_tarxz(from_path):
        with tarfile.open(from_path, 'r:xz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        # A bare .gz holds a single file: write it next to the other
        # outputs under the archive's name with the ".gz" suffix removed.
        to_path = os.path.join(
            to_path,
            os.path.splitext(os.path.basename(from_path))[0]
        )
        # Open the source first so a missing archive does not leave an
        # empty output file behind, and stream the data instead of loading
        # the whole decompressed payload into memory.
        with gzip.GzipFile(from_path) as zip_f, open(to_path, "wb") as out_f:
            shutil.copyfileobj(zip_f, out_f)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("file format not supported")

    if remove_finished:
        os.remove(from_path)