| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import numpy as np
- from tensorflow.keras import Input, Model
- from tensorflow.keras.optimizers import Adam
- from Network.BiGAN.Discriminator import build_discriminator
- from Network.BiGAN.Encoder import build_encoder
- from Network.BiGAN.Generator import build_generator
def init_network(latent_dim, data_dim=121):
    """Build and compile the three BiGAN sub-models and the combined model.

    Args:
        latent_dim: Dimensionality of the latent vector ``z``.
        data_dim: Width of the flat data samples fed to the encoder /
            produced by the generator. Defaults to 121, the value that was
            previously hard-coded.

    Returns:
        Tuple ``(generator, discriminator, encoder, bigan_generator)`` where
        ``bigan_generator`` is the combined model that trains the generator
        and encoder against a frozen discriminator.
    """
    discriminator = build_discriminator(latent_dim)
    optimizer = Adam(0.00001, 0.5)
    # Single-output model: pass the loss as a plain string, not a list.
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])

    generator = build_generator(latent_dim)
    encoder = build_encoder(latent_dim)

    # Freeze the discriminator inside the combined model; it was compiled
    # (trainable) above, so direct discriminator.train_on_batch still updates it.
    discriminator.trainable = False

    z = Input(shape=(latent_dim,))
    data_ = generator(z)             # G(z): sample generated from noise
    data = Input(shape=(data_dim,))
    z_ = encoder(data)               # E(x): latent code inferred from real data

    # Discriminator judgments on the (latent, data) pairs.
    fake = discriminator([z, data_])
    valid = discriminator([z_, data])

    bigan_generator = Model([z, data], [fake, valid])
    bigan_generator.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
                            optimizer=optimizer)
    return generator, discriminator, encoder, bigan_generator
def run_bigan(train_x, train_y, epochs, batch_size=50, latent_dim=32):
    """Train a BiGAN on ``train_x`` and return the trained generator.

    Args:
        train_x: 2-D array of training samples, shape ``(num_samples, 121)``
            (the width expected by ``init_network`` — TODO confirm at callers).
        train_y: Unused; kept only for backward compatibility of the signature.
        epochs: Number of training iterations (one random batch per iteration).
        batch_size: Number of samples per batch.
        latent_dim: Dimensionality of the latent space.

    Returns:
        The trained generator model.
    """
    generator, discriminator, encoder, bigan_generator = init_network(latent_dim)

    # Adversarial ground-truth labels: 1 = "real" pair, 0 = "fake" pair.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # ---------------------
        #  Train the discriminator
        # ---------------------
        z = np.random.normal(0, 1, (batch_size, latent_dim))
        data_ = generator.predict(z)

        # Sample a random real batch and encode it.
        inx = np.random.randint(0, train_x.shape[0], batch_size)
        data = train_x[inx]
        z_ = encoder.predict(data)

        # (E(x), x) should be judged real; (z, G(z)) should be judged fake.
        d_loss_real = discriminator.train_on_batch([z_, data], valid)
        d_loss_fake = discriminator.train_on_batch([z, data_], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train the generator and encoder
        # ---------------------
        # Through the combined model the generator/encoder try to invert the
        # discriminator's labels: (z, G(z)) -> real, (E(x), x) -> fake.
        g_loss = bigan_generator.train_on_batch([z, data], [valid, fake])

        print("%d [D loss: %f, acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss[0]))
    return generator
|