- import numpy as np
- from tensorflow.keras import Input, Model
- from tensorflow.keras.optimizers import Adam
- from Network.ExponentialMovingAverage import ExponentialMovingAverage
- from Network.GAN.Discriminator import build_discriminator
- from Network.GAN.Generator import build_generator
def init_network(latent_dim):
    """Build and compile the three GAN models.

    Args:
        latent_dim: Size of the generator's input noise vector.

    Returns:
        Tuple ``(generator, discriminator, combined)`` where ``combined``
        is the stacked generator -> discriminator model used to train the
        generator (the discriminator's weights are frozen inside it).
    """
    generator = build_generator(latent_dim)
    discriminator = build_discriminator()

    # Separate optimizers so the discriminator and generator updates do not
    # share Adam moment state. Positional args: lr=1e-5, beta_1=0.9.
    optimizer_d = Adam(0.00001, 0.9)
    optimizer_g = Adam(0.00001, 0.9)

    # The standalone discriminator trains on real-vs-fake classification.
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer_d,
                          metrics=['accuracy'])

    # Stacked model: noise z -> generator -> (frozen) discriminator -> validity.
    z = Input(shape=(latent_dim,))
    fake = generator(z)
    # Freezing here only affects the combined model compiled below; the
    # standalone discriminator compiled above still updates its weights.
    discriminator.trainable = False
    validity = discriminator(fake)

    combined = Model(z, validity)
    combined.compile(loss='binary_crossentropy', optimizer=optimizer_g)

    return generator, discriminator, combined
def run_network(train_x, train_y, epochs, batch_size=50, latent_dim=32):
    """Run the adversarial training loop and return the trained generator.

    Args:
        train_x: Real training samples, indexable along axis 0.
        train_y: Unused here; kept for interface compatibility with callers.
        epochs: Number of training iterations (one batch per iteration).
        batch_size: Samples per discriminator/generator update.
        latent_dim: Dimension of the generator's noise input.

    Returns:
        The trained generator model.
    """
    generator, discriminator, combined = init_network(latent_dim)

    # Target labels: 1 for real samples, 0 for generated ones.
    valid = np.ones((batch_size, 1))
    invalid = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator step: one batch of real, one batch of fake ---
        sample_idx = np.random.randint(0, train_x.shape[0], batch_size)
        real = train_x[sample_idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake = generator.predict(noise)
        d_loss_real = discriminator.train_on_batch(real, valid)
        d_loss_fake = discriminator.train_on_batch(fake, invalid)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: train through the combined model with the
        # labels flipped to "real" so the generator learns to fool D ---
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, valid)

        # Progress report for this iteration.
        print("%d [Discriminator loss: %f, accuracy: %.2f%%] [Generator loss: %f]" % (
            epoch, d_loss[0], 100 * d_loss[1], g_loss))

    return generator