"""RunBiGanNetwork.py — build and train a BiGAN (bidirectional GAN)."""
  1. import numpy as np
  2. from tensorflow.keras import Input, Model
  3. from tensorflow.keras.optimizers import Adam
  4. from Network.BiGAN.Discriminator import build_discriminator
  5. from Network.BiGAN.Encoder import build_encoder
  6. from Network.BiGAN.Generator import build_generator
  7. def init_network(latent_dim):
  8. discriminator = build_discriminator(latent_dim)
  9. optimizer = Adam(0.00001, 0.5)
  10. # optimizer_d = Adam(0.000005, 0.5)
  11. # optimizer_g = Adam(0.00001, 0.5)
  12. discriminator.compile(loss=['binary_crossentropy'],
  13. optimizer=optimizer,
  14. metrics=['accuracy'])
  15. generator = build_generator(latent_dim)
  16. encoder = build_encoder(latent_dim)
  17. discriminator.trainable = False
  18. z = Input(shape=(latent_dim,))
  19. data_ = generator(z)
  20. data = Input(shape=(121,))
  21. z_ = encoder(data)
  22. fake = discriminator([z, data_])
  23. valid = discriminator([z_, data])
  24. bigan_generator = Model([z, data], [fake, valid])
  25. bigan_generator.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
  26. optimizer=optimizer)
  27. return generator, discriminator, encoder, bigan_generator
  28. def run_bigan(train_x, train_y, epochs, batch_size=50, latent_dim=32):
  29. generator, discriminator, encoder, bigan_generator = init_network(latent_dim)
  30. valid = np.ones((batch_size, 1))
  31. fake = np.zeros((batch_size, 1))
  32. for epoch in range(epochs):
  33. # ---------------------
  34. # 训练判别器
  35. # ---------------------
  36. z = np.random.normal(0, 1, (batch_size, latent_dim))
  37. data_ = generator.predict(z)
  38. inx = np.random.randint(0, train_x.shape[0], batch_size)
  39. data = train_x[inx]
  40. z_ = encoder.predict(data)
  41. d_loss_real = discriminator.train_on_batch([z_, data], valid)
  42. d_loss_fake = discriminator.train_on_batch([z, data_], fake)
  43. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  44. # ---------------------
  45. # 训练生成器
  46. # ---------------------
  47. g_loss = bigan_generator.train_on_batch([z, data], [valid, fake])
  48. print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0]))
  49. return generator