| 123456789101112131415161718192021222324252627282930313233343536373839 |
- from keras import backend as K
- class ExponentialMovingAverage:
- """对模型权重进行指数滑动平均。
- 用法:在model.compile之后、第一次训练之前使用;
- 先初始化对象,然后执行inject方法。
- """
- def __init__(self, model, momentum=0.9999):
- self.momentum = momentum
- self.model = model
- self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights]
- def inject(self):
- """添加更新算子到model.metrics_updates。
- """
- self.initialize()
- for w1, w2 in zip(self.ema_weights, self.model.weights):
- op = K.moving_average_update(w1, w2, self.momentum)
- self.model.add_metric(op, name='exponential_moving_average')
- def initialize(self):
- """ema_weights初始化跟原模型初始化一致。
- """
- self.old_weights = K.batch_get_value(self.model.weights)
- K.batch_set_value(zip(self.ema_weights, self.old_weights))
- def apply_ema_weights(self):
- """备份原模型权重,然后将平均权重应用到模型上去。
- """
- self.old_weights = K.batch_get_value(self.model.weights)
- ema_weights = K.batch_get_value(self.ema_weights)
- K.batch_set_value(zip(self.model.weights, ema_weights))
- def reset_old_weights(self):
- """恢复模型到旧权重。
- """
- K.batch_set_value(zip(self.model.weights, self.old_weights))
|