RunNetwork.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import numpy as np
  2. from tensorflow.keras import Input, Model
  3. from tensorflow.keras.optimizers import Adam
  4. from Network.ExponentialMovingAverage import ExponentialMovingAverage
  5. from Network.GAN.Discriminator import build_discriminator
  6. from Network.GAN.Generator import build_generator
  7. def init_network(latent_dim):
  8. generator = build_generator(latent_dim)
  9. discriminator = build_discriminator()
  10. optimizer = Adam(0.00001, 0.9)
  11. optimizer_d = Adam(0.00001, 0.9)
  12. optimizer_g = Adam(0.00001, 0.9)
  13. discriminator.compile(loss='binary_crossentropy',
  14. optimizer=optimizer_d,
  15. metrics=['accuracy']
  16. )
  17. # ema = ExponentialMovingAverage(discriminator)
  18. # ema.inject()
  19. z = Input(shape=(latent_dim,))
  20. fake = generator(z)
  21. discriminator.trainable = False
  22. validity = discriminator(fake)
  23. combined = Model(z, validity)
  24. combined.compile(loss='binary_crossentropy',
  25. optimizer=optimizer_g,
  26. )
  27. # ema = ExponentialMovingAverage(combined)
  28. # ema.inject()
  29. return generator, discriminator, combined
  30. def run_network(train_x, train_y, epochs, batch_size=50, latent_dim=32):
  31. generator, discriminator, combined = init_network(latent_dim)
  32. valid = np.ones((batch_size, 1))
  33. invalid = np.zeros((batch_size, 1))
  34. for epoch in range(epochs):
  35. # ---------------------
  36. # 训练判别器
  37. # ---------------------
  38. inx = np.random.randint(0, train_x.shape[0], batch_size)
  39. real = train_x[inx]
  40. noise = np.random.normal(0, 1, (batch_size, latent_dim))
  41. fake = generator.predict(noise)
  42. d_loss_real = discriminator.train_on_batch(real, valid)
  43. d_loss_fake = discriminator.train_on_batch(fake, invalid)
  44. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  45. # ---------------------
  46. # 训练生成器
  47. # ---------------------
  48. noise = np.random.normal(0, 1, (batch_size, latent_dim))
  49. g_loss = combined.train_on_batch(noise, valid)
  50. # 显示进度
  51. print("%d [Discriminator loss: %f, accuracy: %.2f%%] [Generator loss: %f]" % (
  52. epoch, d_loss[0], 100 * d_loss[1], g_loss))
  53. return generator