# interactive.py

import torch as t
from torch.nn import ReLU, Linear, LazyLinear, Tanh, Sigmoid, Dropout, Sequential
from federatedml.nn.backend.torch.base import FateTorchLayer


class InteractiveLayer(t.nn.Module, FateTorchLayer):
    r"""A :class:`InteractiveLayer`.

    An interface for InteractiveLayer. In the interactive layer, the forward computation is:

        out = activation(Linear(guest_input) + Linear(host_0_input) + Linear(host_1_input) + ...)

    Args:
        out_dim: int, the output dimension of the InteractiveLayer.
        host_num: int, the number of host parties, default 1; set this parameter
            when running multi-party modeling.
        guest_dim: int or None, the input dimension of guest features; if None, a
            LazyLinear layer is used that automatically infers the input dimension.
        host_dim: int or None.
            int: the input dimension of all host features.
            None: automatically infer the input dimension of all host features.
        activation: str, supported values are relu, tanh, sigmoid.
        dropout: float between 0 and 1; if None, dropout is disabled.
        guest_bias: bool, whether the guest linear layer uses a bias term.
        host_bias: bool, whether the host linear layers use a bias term.
        need_guest: bool, if False, the input of the guest bottom model is ignored.
    """

    def __init__(
            self,
            out_dim,
            guest_dim=None,
            host_num=1,
            host_dim=None,
            activation='relu',
            dropout=None,
            guest_bias=True,
            host_bias=True,
            need_guest=True,
    ):
        t.nn.Module.__init__(self)
        FateTorchLayer.__init__(self)

        # resolve the activation function from its string name
        self.activation = None
        if activation is not None:
            if activation.lower() == 'relu':
                self.activation = ReLU()
            elif activation.lower() == 'tanh':
                self.activation = Tanh()
            elif activation.lower() == 'sigmoid':
                self.activation = Sigmoid()
            else:
                raise ValueError(
                    'activation {} is not supported, available: relu, tanh, sigmoid'.format(activation))
        self.dropout = None
        if dropout is not None:
            assert isinstance(dropout, float), 'dropout must be a float'
            self.dropout = Dropout(p=dropout)

        assert isinstance(out_dim, int), 'out_dim must be an int >= 0'

        # record constructor parameters in param_dict (used by the FATE torch backend)
        self.param_dict['out_dim'] = out_dim
        self.param_dict['activation'] = activation
        self.param_dict['dropout'] = dropout
        self.param_dict['need_guest'] = need_guest

        assert isinstance(host_num, int) and host_num >= 1, 'host_num must be an int >= 1'
        self.param_dict['host_num'] = host_num

        if guest_dim is not None:
            assert isinstance(guest_dim, int)
        if host_dim is not None:
            assert isinstance(host_dim, int)

        self.guest_bias = guest_bias
        self.param_dict['guest_dim'] = guest_dim
        self.param_dict['host_dim'] = host_dim
        self.param_dict['guest_bias'] = guest_bias
        self.param_dict['host_bias'] = host_bias

        # guest linear layer: lazy when the input dimension is unknown
        if need_guest:
            if guest_dim is None:
                self.guest_model = LazyLinear(out_dim, guest_bias)
            else:
                self.guest_model = Linear(guest_dim, out_dim, guest_bias)
        else:
            self.guest_model = None

        self.out_dim = out_dim
        self.host_dim = host_dim
        self.host_bias = host_bias
        self.need_guest = need_guest

        # one linear layer per host party
        self.host_model = t.nn.ModuleList()
        for _ in range(host_num):
            self.host_model.append(self.make_host_model())

        # activation sequence (activation, optionally followed by dropout)
        if self.dropout is not None:
            self.act_seq = Sequential(
                self.activation,
                self.dropout
            )
        else:
            self.act_seq = Sequential(
                self.activation
            )

    def lazy_to_linear(self, guest_dim=None, host_dims=None):
        # replace LazyLinear layers with concrete Linear layers once the
        # input dimensions are known
        if isinstance(self.guest_model, t.nn.LazyLinear) and guest_dim is not None:
            self.guest_model = t.nn.Linear(
                guest_dim, self.out_dim, bias=self.guest_bias)
        if isinstance(self.host_model[0], t.nn.LazyLinear) and host_dims is not None:
            new_model_list = t.nn.ModuleList()
            for dim in host_dims:
                new_model_list.append(
                    t.nn.Linear(dim, self.out_dim, bias=self.host_bias))
            self.host_model = new_model_list

    def make_host_model(self):
        if self.host_dim is None:
            return LazyLinear(self.out_dim, self.host_bias)
        else:
            return Linear(self.host_dim, self.out_dim, self.host_bias)

    def forward(self, x_guest, x_host):
        if self.need_guest:
            g_out = self.guest_model(x_guest)
        else:
            g_out = 0

        # x_host is either a single tensor or a list with one tensor per host party;
        # host outputs are summed before the activation is applied
        h_out = None
        if isinstance(x_host, list):
            for m, data in zip(self.host_model, x_host):
                out_ = m(data)
                if h_out is None:
                    h_out = out_
                else:
                    h_out += out_
        else:
            h_out = self.host_model[0](x_host)

        return self.activation(g_out + h_out)
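

# Usage sketch (not part of the original module): a minimal, hypothetical example of
# building an InteractiveLayer and running a local forward pass. The batch size, the
# guest/host feature dimensions (8 and 4) and the host count of 2 are illustrative
# assumptions, not values required by FATE.
if __name__ == '__main__':
    layer = InteractiveLayer(out_dim=16, host_num=2, activation='relu')

    batch_size = 32
    x_guest = t.randn(batch_size, 8)                            # guest bottom-model output
    x_host = [t.randn(batch_size, 4), t.randn(batch_size, 4)]   # one tensor per host party

    out = layer(x_guest, x_host)     # LazyLinear layers infer their input dims on first call
    print(out.shape)                 # torch.Size([32, 16])

    # lazy_to_linear materializes the LazyLinear layers into plain Linear layers
    # once the input dimensions are known, without running a forward pass first
    lazy = InteractiveLayer(out_dim=16, host_num=2)
    lazy.lazy_to_linear(guest_dim=8, host_dims=[4, 4])
    print(lazy(x_guest, x_host).shape)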