utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import logging
  2. from collections import defaultdict
  3. import numpy as np
  4. logger = logging.getLogger(__name__)
  5. class AverageMeter(object):
  6. """Computes and stores the average and current value"""
  7. def __init__(self):
  8. self.reset()
  9. def reset(self):
  10. self.val = 0
  11. self.avg = 0
  12. self.std = 0
  13. self.sum = 0
  14. self.sumsq = 0
  15. self.count = 0
  16. self.lst = []
  17. def update(self, val, n=1):
  18. self.val = float(val)
  19. self.sum += float(val) * n
  20. # self.sumsq += float(val)**2
  21. self.count += n
  22. self.avg = self.sum / self.count
  23. self.lst.append(self.val)
  24. self.std = np.std(self.lst)
  25. class ProgressTable:
  26. def __init__(self, table_list):
  27. if len(table_list) == 0:
  28. print()
  29. return
  30. self.lens = defaultdict(int)
  31. self.table_list = table_list
  32. self.construct(table_list)
  33. def construct(self, table_list):
  34. self.lens = defaultdict(int)
  35. self.table_list = table_list
  36. for i in table_list:
  37. for ii, to_print in enumerate(i):
  38. for title, val in to_print.items():
  39. self.lens[(title, ii)] = max(self.lens[(title, ii)], max(len(title), len(val)))
  40. def print_table_header(self):
  41. for ii, to_print in enumerate(self.table_list[0]):
  42. for title, val in to_print.items():
  43. print('{0:^{1}}'.format(title, self.lens[(title, ii)]), end=" ")
  44. def print_table_content(self):
  45. for i in self.table_list:
  46. print()
  47. for ii, to_print in enumerate(i):
  48. for title, val in to_print.items():
  49. print('{0:^{1}}'.format(val, self.lens[(title, ii)]), end=" ", flush=True)
  50. def print_all_table(self):
  51. self.print_table_header()
  52. self.print_table_content()
  53. def print_table(self, header_condition, content_condition):
  54. if header_condition:
  55. self.print_table_header()
  56. if content_condition:
  57. self.print_table_content()
  58. def update_table_list(self, table_list):
  59. self.construct(table_list)
  60. def print_table(table_list):
  61. if len(table_list) == 0:
  62. print()
  63. return
  64. lens = defaultdict(int)
  65. for i in table_list:
  66. for ii, to_print in enumerate(i):
  67. for title, val in to_print.items():
  68. lens[(title, ii)] = max(lens[(title, ii)], max(len(title), len(val)))
  69. # printed_table_list_header = []
  70. for ii, to_print in enumerate(table_list[0]):
  71. for title, val in to_print.items():
  72. print('{0:^{1}}'.format(title, lens[(title, ii)]), end=" ")
  73. for i in table_list:
  74. print()
  75. for ii, to_print in enumerate(i):
  76. for title, val in to_print.items():
  77. print('{0:^{1}}'.format(val, lens[(title, ii)]), end=" ", flush=True)
  78. print()