coae.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from federatedml.util import LOGGER
  2. from federatedml.util import consts
  3. try:
  4. import torch
  5. import torch as t
  6. from torch import nn
  7. from torch.nn import Module
  8. from torch.nn import functional as F
  9. except ImportError:
  10. Module = object
  11. def entropy(tensor):
  12. return -t.sum(tensor * t.log2(tensor))
  13. def cross_entropy(p2, p1, reduction='mean'):
  14. p2 = p2 + consts.FLOAT_ZERO # to avoid nan
  15. assert p2.shape == p1.shape
  16. if reduction == 'sum':
  17. return -t.sum(p1 * t.log(p2))
  18. elif reduction == 'mean':
  19. return -t.mean(t.sum(p1 * t.log(p2), dim=1))
  20. elif reduction == 'none':
  21. return -t.sum(p1 * t.log(p2), dim=1)
  22. else:
  23. raise ValueError('unknown reduction')
  24. def cross_entropy_for_one_hot(pred, target, reduce="mean"):
  25. if reduce == "mean":
  26. return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))
  27. elif reduce == "sum":
  28. return torch.sum(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))
  29. else:
  30. raise Exception("Does not support reduce [{}]".format(reduce))
  31. def coae_loss(
  32. label,
  33. fake_label,
  34. reconstruct_label,
  35. lambda_1=10,
  36. lambda_2=2,
  37. verbose=False):
  38. loss_a = cross_entropy(reconstruct_label, label) - \
  39. lambda_1 * cross_entropy(fake_label, label)
  40. loss_b = entropy(fake_label)
  41. if verbose:
  42. LOGGER.debug(
  43. 'loss a is {} {}'.format(
  44. cross_entropy(
  45. reconstruct_label, label), cross_entropy(
  46. fake_label, label)))
  47. LOGGER.debug('loss b is {}'.format(loss_b))
  48. return loss_a - lambda_2 * loss_b
  49. class CrossEntropy(object):
  50. def __init__(self, reduction='mean'):
  51. self.reduction = reduction
  52. def __call__(self, p2, p1):
  53. return cross_entropy(p2, p1, self.reduction)
  54. class CoAE(Module):
  55. def __init__(self, input_dim=2, encode_dim=None):
  56. super(CoAE, self).__init__()
  57. self.d = input_dim
  58. if encode_dim is None:
  59. encode_dim = (6 * input_dim) ** 2
  60. self.encoder = nn.Sequential(
  61. nn.Linear(input_dim, encode_dim),
  62. nn.ReLU(),
  63. nn.Linear(encode_dim, input_dim),
  64. nn.Softmax(dim=1)
  65. )
  66. self.decoder = nn.Sequential(
  67. nn.Linear(input_dim, encode_dim),
  68. nn.ReLU(),
  69. nn.Linear(encode_dim, input_dim),
  70. nn.Softmax(dim=1)
  71. )
  72. def encode(self, x):
  73. x = t.Tensor(x)
  74. return self.encoder(x)
  75. def decode(self, fake_labels):
  76. fake_labels = t.Tensor(fake_labels)
  77. return self.decoder(fake_labels)
  78. def forward(self, x):
  79. x = t.Tensor(x)
  80. z = self.encoder(x)
  81. return self.decoder(z), z
  82. def train_an_autoencoder_confuser(
  83. label_num,
  84. epoch=50,
  85. lambda1=1,
  86. lambda2=2,
  87. lr=0.001,
  88. verbose=False):
  89. coae = CoAE(label_num, )
  90. labels = torch.eye(label_num)
  91. opt = torch.optim.Adam(coae.parameters(), lr=lr)
  92. for i in range(epoch):
  93. opt.zero_grad()
  94. fake_labels = coae.encode(labels)
  95. reconstruct_labels = coae.decode(fake_labels)
  96. loss = coae_loss(
  97. labels,
  98. fake_labels,
  99. reconstruct_labels,
  100. lambda1,
  101. lambda2,
  102. verbose=verbose)
  103. loss.backward()
  104. opt.step()
  105. if verbose:
  106. LOGGER.debug(
  107. 'origin labels {}, fake labels {}, reconstruct labels {}'.format(
  108. labels, coae.encode(labels).detach().numpy(), coae.decode(
  109. coae.encode(labels)).detach().numpy()))
  110. return coae
  111. def coae_label_reformat(labels, label_num):
  112. LOGGER.debug('label shape is {}'.format(labels.shape))
  113. labels = labels
  114. if label_num == 1: # regression:
  115. raise ValueError('label num ==1, regression task not support COAE')
  116. else:
  117. return nn.functional.one_hot(
  118. t.Tensor(labels).flatten().type(
  119. t.int64), label_num).numpy()
  120. if __name__ == '__main__':
  121. coae = train_an_autoencoder_confuser(
  122. 2,
  123. epoch=1000,
  124. verbose=True,
  125. lambda1=2.0,
  126. lambda2=1.0,
  127. lr=0.02)