optimization_strategy.py

  1. """Optimization setups."""
  2. from dataclasses import dataclass
  3. def training_strategy(strategy, lr=None, epochs=None, dryrun=False):
  4. """Parse training strategy."""
  5. if strategy == 'conservative':
  6. defs = ConservativeStrategy(lr, epochs, dryrun)
  7. elif strategy == 'adam':
  8. defs = AdamStrategy(lr, epochs, dryrun)
  9. else:
  10. raise ValueError('Unknown training strategy.')
  11. return defs
@dataclass
class Strategy:
    """Default usual parameters, not intended for parsing."""

    epochs: int
    batch_size: int
    optimizer: str
    lr: float
    scheduler: str
    weight_decay: float
    validate: int
    warmup: bool
    dryrun: bool
    dropout: float
    augmentations: bool

    def __init__(self, lr=None, epochs=None, dryrun=False):
        """Defaulted parameters. Apply overwrites from args."""
        if epochs is not None:
            self.epochs = epochs
        if lr is not None:
            self.lr = lr
        if dryrun:
            self.dryrun = dryrun
        self.validate = 10
@dataclass
class ConservativeStrategy(Strategy):
    """Default usual parameters, defines a config object."""

    def __init__(self, lr=None, epochs=None, dryrun=False):
        """Initialize training hyperparameters."""
        self.lr = 0.1
        self.epochs = 120
        self.batch_size = 128
        self.optimizer = 'SGD'
        self.scheduler = 'linear'
        self.warmup = False
        self.weight_decay = 5e-4
        self.dropout = 0.0
        self.augmentations = True
        self.dryrun = False
        # Apply user overrides; the base class skips arguments left at None/False.
        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)
@dataclass
class AdamStrategy(Strategy):
    """Start slowly. Use a tame Adam."""

    def __init__(self, lr=None, epochs=None, dryrun=False):
        """Initialize training hyperparameters."""
        self.lr = 1e-3 / 10
        self.epochs = 120
        self.batch_size = 32
        self.optimizer = 'AdamW'
        self.scheduler = 'linear'
        self.warmup = True
        self.weight_decay = 5e-4
        self.dropout = 0.0
        self.augmentations = True
        self.dryrun = False
        # Apply user overrides; the base class skips arguments left at None/False.
        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)
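

# Illustrative usage, not part of the original module: a minimal sketch of how
# the factory above might be consumed from a training script. The concrete
# argument values below are assumptions chosen only for demonstration.
if __name__ == '__main__':
    setup = training_strategy('adam', lr=2e-4, epochs=80, dryrun=True)
    print(f'{setup.optimizer} for {setup.epochs} epochs at lr={setup.lr}, '
          f'batch size {setup.batch_size}, validate={setup.validate}, dryrun={setup.dryrun}')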