utils.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. """Various utilities."""
  2. import os
  3. import csv
  4. import torch
  5. import random
  6. import numpy as np
  7. import socket
  8. import datetime
  9. def system_startup(args=None, defs=None):
  10. """Print useful system information."""
  11. # Choose GPU device and print status information:
  12. device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
  13. setup = dict(device=device, dtype=torch.float) # non_blocking=NON_BLOCKING
  14. print('Currently evaluating -------------------------------:')
  15. print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
  16. print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.')
  17. if args is not None:
  18. print(args)
  19. if defs is not None:
  20. print(repr(defs))
  21. if torch.cuda.is_available():
  22. print(f'GPU : {torch.cuda.get_device_name(device=device)}')
  23. return setup
  24. def save_to_table(out_dir, name, dryrun, **kwargs):
  25. """Save keys to .csv files. Function adapted from Micah."""
  26. # Check for file
  27. if not os.path.isdir(out_dir):
  28. os.makedirs(out_dir)
  29. fname = os.path.join(out_dir, f'table_{name}.csv')
  30. fieldnames = list(kwargs.keys())
  31. # Read or write header
  32. try:
  33. with open(fname, 'r') as f:
  34. reader = csv.reader(f, delimiter='\t')
  35. header = [line for line in reader][0]
  36. except Exception as e:
  37. print('Creating a new .csv table...')
  38. with open(fname, 'w') as f:
  39. writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
  40. writer.writeheader()
  41. if not dryrun:
  42. # Add row for this experiment
  43. with open(fname, 'a') as f:
  44. writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
  45. writer.writerow(kwargs)
  46. print('\nResults saved to ' + fname + '.')
  47. else:
  48. print(f'Would save results to {fname}.')
  49. print(f'Would save these keys: {fieldnames}.')
  50. def set_random_seed(seed=233):
  51. """233 = 144 + 89 is my favorite number."""
  52. torch.manual_seed(seed + 1)
  53. torch.cuda.manual_seed(seed + 2)
  54. torch.cuda.manual_seed_all(seed + 3)
  55. np.random.seed(seed + 4)
  56. torch.cuda.manual_seed_all(seed + 5)
  57. random.seed(seed + 6)
  58. def set_deterministic():
  59. """Switch pytorch into a deterministic computation mode."""
  60. torch.backends.cudnn.deterministic = True
  61. torch.backends.cudnn.benchmark = False