ExponentialMovingAverage.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. from keras import backend as K
  2. class ExponentialMovingAverage:
  3. """对模型权重进行指数滑动平均。
  4. 用法:在model.compile之后、第一次训练之前使用;
  5. 先初始化对象,然后执行inject方法。
  6. """
  7. def __init__(self, model, momentum=0.9999):
  8. self.momentum = momentum
  9. self.model = model
  10. self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights]
  11. def inject(self):
  12. """添加更新算子到model.metrics_updates。
  13. """
  14. self.initialize()
  15. for w1, w2 in zip(self.ema_weights, self.model.weights):
  16. op = K.moving_average_update(w1, w2, self.momentum)
  17. self.model.add_metric(op, name='exponential_moving_average')
  18. def initialize(self):
  19. """ema_weights初始化跟原模型初始化一致。
  20. """
  21. self.old_weights = K.batch_get_value(self.model.weights)
  22. K.batch_set_value(zip(self.ema_weights, self.old_weights))
  23. def apply_ema_weights(self):
  24. """备份原模型权重,然后将平均权重应用到模型上去。
  25. """
  26. self.old_weights = K.batch_get_value(self.model.weights)
  27. ema_weights = K.batch_get_value(self.ema_weights)
  28. K.batch_set_value(zip(self.model.weights, ema_weights))
  29. def reset_old_weights(self):
  30. """恢复模型到旧权重。
  31. """
  32. K.batch_set_value(zip(self.model.weights, self.old_weights))