# he_interactive_layer.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pickle
import numpy as np
import torch
from torch import autograd
from federatedml.nn.hetero.interactive.base import InteractiveLayerGuest, InteractiveLayerHost
from federatedml.nn.hetero.nn_component.torch_model import backward_loss
from federatedml.nn.backend.torch.interactive import InteractiveLayer
from federatedml.nn.backend.torch.serialization import recover_sequential_from_dict
from federatedml.util.fixpoint_solver import FixedPointEncoder
from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import InteractiveLayerParam
from federatedml.secureprotol import PaillierEncrypt
from federatedml.util import consts, LOGGER
from federatedml.nn.hetero.interactive.utils.numpy_layer import NumpyDenseLayerGuest, NumpyDenseLayerHost
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.nn.hetero.nn_component.torch_model import TorchNNModel
from federatedml.transfer_variable.base_transfer_variable import BaseTransferVariables
from fate_arch.session import computing_session as session
from federatedml.nn.backend.utils.rng import RandomNumberGenerator

PLAINTEXT = False
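
# This module implements the Hetero-NN interactive layer under Paillier
# homomorphic encryption: hosts encrypt their bottom-model outputs, the guest
# evaluates its linear layer on the ciphertexts, and random noise keeps
# plaintexts hidden from whichever party decrypts. Setting PLAINTEXT = True
# bypasses encryption; the plaintext interfaces exist for testing only.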


class HEInteractiveTransferVariable(BaseTransferVariables):

    def __init__(self, flowid=0):
        super().__init__(flowid)
        self.decrypted_guest_forward = self._create_variable(
            name='decrypted_guest_forward', src=['host'], dst=['guest'])
        self.decrypted_guest_weight_gradient = self._create_variable(
            name='decrypted_guest_weight_gradient', src=['host'], dst=['guest'])
        self.encrypted_acc_noise = self._create_variable(
            name='encrypted_acc_noise', src=['host'], dst=['guest'])
        self.encrypted_guest_forward = self._create_variable(
            name='encrypted_guest_forward', src=['guest'], dst=['host'])
        self.encrypted_guest_weight_gradient = self._create_variable(
            name='encrypted_guest_weight_gradient', src=['guest'], dst=['host'])
        self.encrypted_host_forward = self._create_variable(
            name='encrypted_host_forward', src=['host'], dst=['guest'])
        self.host_backward = self._create_variable(
            name='host_backward', src=['guest'], dst=['host'])
        self.selective_info = self._create_variable(
            name="selective_info", src=["guest"], dst=["host"])
        self.drop_out_info = self._create_variable(
            name="drop_out_info", src=["guest"], dst=["host"])
        self.drop_out_table = self._create_variable(
            name="drop_out_table", src=["guest"], dst=["host"])
        self.interactive_layer_output_unit = self._create_variable(
            name="interactive_layer_output_unit", src=["guest"], dst=["host"])
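

# DropOut implements inverted dropout for the interactive layer: a boolean
# keep-mask is sampled per batch and outputs are scaled by 1 / keep_rate at
# train time. generate_mask_table materializes the mask as a distributed
# table so it can be shared with the hosts via the drop_out_table variable.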
class DropOut(object):

    def __init__(self, rate, noise_shape):
        self._keep_rate = rate
        self._noise_shape = noise_shape
        self._batch_size = noise_shape[0]
        self._mask = None
        self._partition = None
        self._mask_table = None
        self._select_mask_table = None
        self._do_backward_select = False
        self._mask_table_cache = {}

    def forward(self, X):
        if X.shape == self._mask.shape:
            forward_x = X * self._mask / self._keep_rate
        else:
            forward_x = X * self._mask[0: len(X)] / self._keep_rate
        return forward_x

    def backward(self, grad):
        if self._do_backward_select:
            self._mask = self._select_mask_table[0: grad.shape[0]]
            self._select_mask_table = self._select_mask_table[grad.shape[0]:]
            return grad * self._mask / self._keep_rate
        else:
            if grad.shape == self._mask.shape:
                return grad * self._mask / self._keep_rate
            else:
                return grad * self._mask[0: grad.shape[0]] / self._keep_rate

    def generate_mask(self):
        self._mask = np.random.uniform(
            low=0, high=1, size=self._noise_shape) < self._keep_rate

    def generate_mask_table(self, shape):
        # generate mask table according to samples shape, because in some
        # batches, sample_num < batch_size
        if shape == self._noise_shape:
            _mask_table = session.parallelize(
                self._mask, include_key=False, partition=self._partition)
        else:
            _mask_table = session.parallelize(
                self._mask[0: shape[0]], include_key=False, partition=self._partition)
        return _mask_table

    def set_partition(self, partition):
        self._partition = partition

    def select_backward_sample(self, select_ids):
        select_mask_table = self._mask[np.array(select_ids)]
        if self._select_mask_table is not None:
            self._select_mask_table = np.vstack(
                (self._select_mask_table, select_mask_table))
        else:
            self._select_mask_table = select_mask_table

    def do_backward_select_strategy(self):
        self._do_backward_select = True
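

# HEInteractiveLayerGuest is the guest-side interactive layer. It holds the
# trainable guest/host linear sub-models as numpy layers, drives the
# encrypted forward/backward exchanges with every host, and applies the
# torch activation of the InteractiveLayer on the merged dense output.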
class HEInteractiveLayerGuest(InteractiveLayerGuest):

    def __init__(self, params=None, layer_config=None, host_num=1):
        super(HEInteractiveLayerGuest, self).__init__(params)
        # transfer var
        self.host_num = host_num
        self.layer_config = layer_config
        self.transfer_variable = HEInteractiveTransferVariable()
        self.plaintext = PLAINTEXT
        self.host_input_shapes = []
        self.rng_generator = RandomNumberGenerator()
        self.learning_rate = params.interactive_layer_lr
        # cached tensor
        self.guest_tensor = None
        self.host_tensors = None
        self.dense_output_data_require_grad = None
        self.activation_out_require_grad = None
        # model
        self.model: InteractiveLayer = None
        self.guest_model = None
        self.host_model_list = []
        self.batch_size = None
        self.partitions = 0
        self.do_backward_select_strategy = False
        self.optimizer = None
        # drop out
        self.drop_out_initiated = False
        self.drop_out = None
        self.drop_out_keep_rate = None
        self.fixed_point_encoder = None if params.floating_point_precision is None \
            else FixedPointEncoder(2 ** params.floating_point_precision)
        self.send_output_unit = False
        # float64
        self.float64 = False

    """
    Init functions
    """

    def set_flow_id(self, flow_id):
        self.transfer_variable.set_flowid(flow_id)

    def set_backward_select_strategy(self):
        self.do_backward_select_strategy = True

    def set_batch(self, batch_size):
        self.batch_size = batch_size

    def set_partition(self, partition):
        self.partitions = partition

    def _build_model(self):
        if self.model is None:
            raise ValueError('torch interactive model is not initialized!')
        for i in range(self.host_num):
            host_model = NumpyDenseLayerHost()
            host_model.build(self.model.host_model[i])
            host_model.set_learning_rate(self.learning_rate)
            self.host_model_list.append(host_model)
        self.guest_model = NumpyDenseLayerGuest()
        self.guest_model.build(self.model.guest_model)
        self.guest_model.set_learning_rate(self.learning_rate)
        if self.do_backward_select_strategy:
            self.guest_model.set_backward_selective_strategy()
            self.guest_model.set_batch(self.batch_size)
            for host_model in self.host_model_list:
                host_model.set_backward_selective_strategy()
                host_model.set_batch(self.batch_size)
  173. """
  174. Drop out functions
  175. """
  176. def init_drop_out_parameter(self):
  177. if isinstance(self.model.param_dict['dropout'], float):
  178. self.drop_out_keep_rate = 1 - self.model.param_dict['dropout']
  179. else:
  180. self.drop_out_keep_rate = -1
  181. self.transfer_variable.drop_out_info.remote(
  182. self.drop_out_keep_rate, idx=-1, suffix=('dropout_rate', ))
  183. self.drop_out_initiated = True
  184. def _create_drop_out(self, shape):
  185. if self.drop_out_keep_rate and self.drop_out_keep_rate != 1 and self.drop_out_keep_rate > 0:
  186. if not self.drop_out:
  187. self.drop_out = DropOut(
  188. noise_shape=shape, rate=self.drop_out_keep_rate)
  189. self.drop_out.set_partition(self.partitions)
  190. if self.do_backward_select_strategy:
  191. self.drop_out.do_backward_select_strategy()
  192. self.drop_out.generate_mask()
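
    # expand_columns is the inverse of the column selection done under
    # dropout: given the kept values and a boolean keep_array, it rebuilds
    # the full-shape array, writing zeros where columns were dropped.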
    @staticmethod
    def expand_columns(tensor, keep_array):
        shape = keep_array.shape
        tensor = np.reshape(tensor, (tensor.size,))
        keep = np.reshape(keep_array, (keep_array.size,))
        ret_tensor = []
        idx = 0
        for x in keep:
            if x == 0:
                ret_tensor.append(0)
            else:
                ret_tensor.append(tensor[idx])
                idx += 1
        return np.reshape(np.array(ret_tensor), shape)
  207. """
  208. Plaintext forward/backward, these interfaces are for testing
  209. """
  210. def plaintext_forward(self, guest_input, epoch=0, batch=0, train=True):
  211. if self.model is None:
  212. self.model = recover_sequential_from_dict(self.layer_config)[0]
  213. if self.float64:
  214. self.model.type(torch.float64)
  215. if self.optimizer is None:
  216. self.optimizer = torch.optim.SGD(
  217. params=self.model.parameters(), lr=self.learning_rate)
  218. if train:
  219. self.model.train()
  220. else:
  221. self.model.eval()
  222. with torch.no_grad():
  223. guest_tensor = torch.from_numpy(guest_input)
  224. host_inputs = self.get_forward_from_host(
  225. epoch, batch, train, idx=-1)
  226. host_tensors = [torch.from_numpy(arr) for arr in host_inputs]
  227. interactive_out = self.model(guest_tensor, host_tensors)
  228. self.guest_tensor = guest_tensor
  229. self.host_tensors = host_tensors
  230. return interactive_out.cpu().detach().numpy()

    def plaintext_backward(self, output_gradient, epoch, batch):
        # compute input gradient
        self.guest_tensor: torch.Tensor = self.guest_tensor.requires_grad_(True)
        for tensor in self.host_tensors:
            tensor.requires_grad_(True)
        out = self.model(self.guest_tensor, self.host_tensors)
        loss = backward_loss(out, torch.from_numpy(output_gradient))
        backward_list = [self.guest_tensor]
        backward_list.extend(self.host_tensors)
        ret_grad = autograd.grad(loss, backward_list)
        # update model
        self.guest_tensor: torch.Tensor = self.guest_tensor.requires_grad_(False)
        for tensor in self.host_tensors:
            tensor.requires_grad_(False)
        self.optimizer.zero_grad()
        out = self.model(self.guest_tensor, self.host_tensors)
        loss = backward_loss(out, torch.from_numpy(output_gradient))
        loss.backward()
        self.optimizer.step()
        self.guest_tensor, self.host_tensors = None, None
        for idx, host_grad in enumerate(ret_grad[1:]):
            self.send_host_backward_to_host(host_grad, epoch, batch, idx=idx)
        return ret_grad[0]
  254. """
  255. Activation forward & backward
  256. """
    def activation_forward(self, dense_out, with_grad=True):
        if with_grad:
            if (self.dense_output_data_require_grad is not None) or (
                    self.activation_out_require_grad is not None):
                raise ValueError(
                    'torch forward error, related required grad tensors are not freed')
            self.dense_output_data_require_grad = dense_out.requires_grad_(True)
            activation_out_ = self.model.activation(
                self.dense_output_data_require_grad)
            self.activation_out_require_grad = activation_out_
        else:
            with torch.no_grad():
                activation_out_ = self.model.activation(dense_out)
        return activation_out_.cpu().detach().numpy()

    def activation_backward(self, output_gradients):
        if self.activation_out_require_grad is None and self.dense_output_data_require_grad is None:
            raise ValueError('related grad is None, cannot compute backward')
        loss = backward_loss(
            self.activation_out_require_grad,
            torch.Tensor(output_gradients))
        activation_backward_grad = torch.autograd.grad(
            loss, self.dense_output_data_require_grad)
        self.activation_out_require_grad = None
        self.dense_output_data_require_grad = None
        return activation_backward_grad[0].cpu().detach().numpy()
  283. """
  284. Forward & Backward
  285. """
  286. def print_log(self, descr, epoch, batch, train):
  287. if train:
  288. LOGGER.info("{} epoch {} batch {}"
  289. "".format(descr, epoch, batch))
  290. else:
  291. LOGGER.info("predicting, {} pred iteration {} batch {}"
  292. "".format(descr, epoch, batch))
    def forward_interactive(
            self,
            encrypted_host_input,
            epoch,
            batch,
            train=True):
        self.print_log(
            'get encrypted dense output of host model of',
            epoch,
            batch,
            train)
        mask_table_list = []
        guest_noises = []
        host_idx = 0
        for model, host_bottom_input in zip(
                self.host_model_list, encrypted_host_input):
            encrypted_fw = model(host_bottom_input, self.fixed_point_encoder)
            mask_table = None
            if train:
                self._create_drop_out(encrypted_fw.shape)
                if self.drop_out:
                    mask_table = self.drop_out.generate_mask_table(
                        encrypted_fw.shape)
                if mask_table:
                    encrypted_fw = encrypted_fw.select_columns(mask_table)
                    mask_table_list.append(mask_table)
            guest_forward_noise = self.rng_generator.fast_generate_random_number(
                encrypted_fw.shape, encrypted_fw.partitions, keep_table=mask_table)
            if self.fixed_point_encoder:
                encrypted_fw += guest_forward_noise.encode(
                    self.fixed_point_encoder)
            else:
                encrypted_fw += guest_forward_noise
            guest_noises.append(guest_forward_noise)
            self.send_guest_encrypted_forward_output_with_noise_to_host(
                encrypted_fw.get_obj(), epoch, batch, idx=host_idx)
            if mask_table:
                self.send_interactive_layer_drop_out_table(
                    mask_table, epoch, batch, idx=host_idx)
            host_idx += 1
        # get decrypted (still noise-masked) outputs back from all hosts
        decrypted_dense_outputs = self.get_guest_decrypted_forward_from_host(
            epoch, batch, idx=-1)
        merge_output = None
        for idx, (outputs, noise) in enumerate(
                zip(decrypted_dense_outputs, guest_noises)):
            out = PaillierTensor(outputs) - noise
            if len(mask_table_list) != 0:
                out = PaillierTensor(
                    out.get_obj().join(
                        mask_table_list[idx],
                        self.expand_columns))
            if merge_output is None:
                merge_output = out
            else:
                merge_output = merge_output + out
        return merge_output

    def forward(self, x, epoch: int, batch: int, train: bool = True, **kwargs):
        self.print_log(
            'interactive layer running forward propagation',
            epoch,
            batch,
            train)
        if self.plaintext:
            return self.plaintext_forward(x, epoch, batch, train)
        if self.model is None:
            self.model = recover_sequential_from_dict(self.layer_config)[0]
            LOGGER.debug('interactive model is {}'.format(self.model))
            # for multi-host cases
            LOGGER.debug(
                'host num is {}, len host model {}'.format(
                    self.host_num, len(self.model.host_model)))
            assert self.host_num == len(self.model.host_model), \
                'host number is {}, but host linear layer number is {}, ' \
                'please check your interactive configuration, make sure ' \
                'that host layer number equals to host number'.format(
                    self.host_num, len(self.model.host_model))
            if self.float64:
                self.model.type(torch.float64)
        if train and not self.drop_out_initiated:
            self.init_drop_out_parameter()
        host_inputs = self.get_forward_from_host(epoch, batch, train, idx=-1)
        host_bottom_inputs_tensor = []
        host_input_shapes = []
        for i in host_inputs:
            pt = PaillierTensor(i)
            host_bottom_inputs_tensor.append(pt)
            host_input_shapes.append(pt.shape[1])
        self.model.lazy_to_linear(x.shape[1], host_dims=host_input_shapes)
        self.host_input_shapes = host_input_shapes
        if self.guest_model is None:
            LOGGER.info("building interactive layers' training model")
            self._build_model()
        if not self.partitions:
            self.partitions = host_bottom_inputs_tensor[0].partitions
        if not self.send_output_unit:
            self.send_output_unit = True
            for idx in range(self.host_num):
                self.send_interactive_layer_output_unit(
                    self.host_model_list[idx].output_shape[0], idx=idx)
        guest_output = self.guest_model(x)
        host_output = self.forward_interactive(
            host_bottom_inputs_tensor, epoch, batch, train)
        if guest_output is not None:
            dense_output_data = host_output + \
                PaillierTensor(guest_output, partitions=self.partitions)
        else:
            dense_output_data = host_output
        self.print_log(
            "start to get interactive layer's activation output of",
            epoch,
            batch,
            train)
        if self.float64:  # result of encrypted computation is float64
            dense_out = torch.from_numpy(dense_output_data.numpy())
        else:
            dense_out = torch.Tensor(
                dense_output_data.numpy())  # convert to float32
        if self.do_backward_select_strategy:
            for h in self.host_model_list:
                h.activation_input = dense_out.cpu().detach().numpy()
        # if not using the selective backward strategy, grads can be computed
        # directly in this forward pass
        if not train or self.do_backward_select_strategy:
            with_grad = False
        else:
            with_grad = True
        activation_out = self.activation_forward(
            dense_out, with_grad=with_grad)
        if train and self.drop_out:
            return self.drop_out.forward(activation_out)
        return activation_out
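
    # backward_interactive computes the weight gradient for a host sub-layer
    # under encryption, blinds it with random noise_w, and has the host
    # decrypt it; subtracting noise_w from the returned value gives the
    # plaintext gradient without exposing it to the host. The host also sends
    # back its encrypted acc_noise, which update_host needs below.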
    def backward_interactive(
            self,
            host_model,
            activation_gradient,
            epoch,
            batch,
            host_idx):
        LOGGER.info(
            "get encrypted weight gradient of epoch {} batch {}".format(
                epoch, batch))
        encrypted_weight_gradient = host_model.get_weight_gradient(
            activation_gradient, encoder=self.fixed_point_encoder)
        if self.fixed_point_encoder:
            encrypted_weight_gradient = self.fixed_point_encoder.decode(
                encrypted_weight_gradient)
        noise_w = self.rng_generator.generate_random_number(
            encrypted_weight_gradient.shape)
        self.transfer_variable.encrypted_guest_weight_gradient.remote(
            encrypted_weight_gradient + noise_w,
            role=consts.HOST,
            idx=host_idx,
            suffix=(epoch, batch,))
        LOGGER.info(
            "get decrypted weight gradient of epoch {} batch {}".format(
                epoch, batch))
        decrypted_weight_gradient = self.transfer_variable.decrypted_guest_weight_gradient.get(
            idx=host_idx, suffix=(epoch, batch,))
        decrypted_weight_gradient -= noise_w
        encrypted_acc_noise = self.get_encrypted_acc_noise_from_host(
            epoch, batch, idx=host_idx)
        return decrypted_weight_gradient, encrypted_acc_noise

    def backward(self, error, epoch: int, batch: int, selective_ids=None):
        if self.plaintext:
            return self.plaintext_backward(error, epoch, batch)
        if selective_ids:
            for host_model in self.host_model_list:
                host_model.select_backward_sample(selective_ids)
            self.guest_model.select_backward_sample(selective_ids)
            if self.drop_out:
                self.drop_out.select_backward_sample(selective_ids)
        if self.do_backward_select_strategy:
            # send to all hosts
            self.send_backward_select_info(
                selective_ids, len(error), epoch, batch, -1)
        if len(error) > 0:
            LOGGER.debug(
                "interactive layer start backward propagation of epoch {} batch {}".format(
                    epoch, batch))
            if not self.do_backward_select_strategy:
                activation_gradient = self.activation_backward(error)
            else:
                act_input = self.host_model_list[0].get_selective_activation_input()
                _ = self.activation_forward(torch.from_numpy(act_input), True)
                activation_gradient = self.activation_backward(error)
            if self.drop_out:
                activation_gradient = self.drop_out.backward(
                    activation_gradient)
            LOGGER.debug(
                "interactive layer update guest weight of epoch {} batch {}".format(
                    epoch, batch))
            # update guest model
            guest_input_gradient = self.update_guest(activation_gradient)
            LOGGER.debug('update host model weights')
            for idx, host_model in enumerate(self.host_model_list):
                # update host models
                host_weight_gradient, acc_noise = self.backward_interactive(
                    host_model, activation_gradient, epoch, batch, host_idx=idx)
                host_input_gradient = self.update_host(
                    host_model, activation_gradient, host_weight_gradient, acc_noise)
                self.send_host_backward_to_host(
                    host_input_gradient.get_obj(), epoch, batch, idx=idx)
            return guest_input_gradient
        else:
            return []
  503. """
  504. Model update
  505. """
    def update_guest(self, activation_gradient):
        input_gradient = self.guest_model.get_input_gradient(
            activation_gradient)
        weight_gradient = self.guest_model.get_weight_gradient(
            activation_gradient)
        self.guest_model.update_weight(weight_gradient)
        self.guest_model.update_bias(activation_gradient)
        return input_gradient

    def update_host(
            self,
            host_model,
            activation_gradient,
            weight_gradient,
            acc_noise):
        activation_gradient_tensor = PaillierTensor(
            activation_gradient, partitions=self.partitions)
        input_gradient = host_model.get_input_gradient(
            activation_gradient_tensor, acc_noise, encoder=self.fixed_point_encoder)
        host_model.update_weight(weight_gradient)
        host_model.update_bias(activation_gradient)
        return input_gradient
  527. """
  528. Communication functions
  529. """
  530. def send_interactive_layer_output_unit(self, shape, idx=0):
  531. self.transfer_variable.interactive_layer_output_unit.remote(
  532. shape, role=consts.HOST, idx=idx)
  533. def send_backward_select_info(
  534. self,
  535. selective_ids,
  536. gradient_len,
  537. epoch,
  538. batch,
  539. idx):
  540. self.transfer_variable.selective_info.remote(
  541. (selective_ids, gradient_len), role=consts.HOST, idx=idx, suffix=(
  542. epoch, batch,))
  543. def send_host_backward_to_host(self, host_error, epoch, batch, idx):
  544. self.transfer_variable.host_backward.remote(host_error,
  545. role=consts.HOST,
  546. idx=idx,
  547. suffix=(epoch, batch,))
  548. def get_forward_from_host(self, epoch, batch, train, idx=0):
  549. return self.transfer_variable.encrypted_host_forward.get(
  550. idx=idx, suffix=(epoch, batch, train))
  551. def send_guest_encrypted_forward_output_with_noise_to_host(
  552. self, encrypted_guest_forward_with_noise, epoch, batch, idx):
  553. return self.transfer_variable.encrypted_guest_forward.remote(
  554. encrypted_guest_forward_with_noise,
  555. role=consts.HOST,
  556. idx=idx,
  557. suffix=(
  558. epoch,
  559. batch,
  560. ))
  561. def send_interactive_layer_drop_out_table(
  562. self, mask_table, epoch, batch, idx):
  563. return self.transfer_variable.drop_out_table.remote(
  564. mask_table, role=consts.HOST, idx=idx, suffix=(epoch, batch,))
  565. def get_guest_decrypted_forward_from_host(self, epoch, batch, idx=0):
  566. return self.transfer_variable.decrypted_guest_forward.get(
  567. idx=idx, suffix=(epoch, batch,))
  568. def get_encrypted_acc_noise_from_host(self, epoch, batch, idx=0):
  569. return self.transfer_variable.encrypted_acc_noise.get(
  570. idx=idx, suffix=(epoch, batch,))
  571. """
  572. Model IO
  573. """
  574. def transfer_np_model_to_torch_interactive_layer(self):
  575. self.model = self.model.cpu()
  576. if self.guest_model is not None:
  577. guest_weight = self.guest_model.get_weight()
  578. model: torch.nn.Linear = self.model.guest_model
  579. model.weight.data.copy_(torch.Tensor(guest_weight))
  580. if self.guest_model.bias is not None:
  581. model.bias.data.copy_(torch.Tensor(self.guest_model.bias))
  582. for host_np_model, torch_model in zip(
  583. self.host_model_list, self.model.host_model):
  584. host_weight = host_np_model.get_weight()
  585. torch_model.weight.data.copy_(torch.Tensor(host_weight))
  586. if host_np_model.bias is not None:
  587. torch_model.bias.data.copy_(torch.Tensor(torch_model.bias))

    def export_model(self):
        self.transfer_np_model_to_torch_interactive_layer()
        interactive_layer_param = InteractiveLayerParam()
        interactive_layer_param.interactive_guest_saved_model_bytes = TorchNNModel.get_model_bytes(
            self.model)
        interactive_layer_param.host_input_shape.extend(self.host_input_shapes)
        return interactive_layer_param

    def restore_model(self, interactive_layer_param):
        self.host_input_shapes = list(interactive_layer_param.host_input_shape)
        self.model = TorchNNModel.recover_model_bytes(
            interactive_layer_param.interactive_guest_saved_model_bytes)
        self._build_model()
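

# HEInteractiveLayerHost is the host-side counterpart of the interactive
# layer: it encrypts the host bottom-model output with its own Paillier key,
# decrypts the noise-masked tensors the guest sends back, and maintains
# acc_noise, the accumulated random offset between the weight copy the guest
# holds and the true host weights.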
class HEInteractiveLayerHost(InteractiveLayerHost):

    def __init__(self, params):
        super(HEInteractiveLayerHost, self).__init__(params)
        self.plaintext = PLAINTEXT
        self.acc_noise = None
        self.learning_rate = params.interactive_layer_lr
        self.encrypter = self.generate_encrypter(params)
        self.transfer_variable = HEInteractiveTransferVariable()
        self.partitions = 1
        self.input_shape = None
        self.output_unit = None
        self.rng_generator = RandomNumberGenerator()
        self.do_backward_select_strategy = False
        self.drop_out_init = False
        self.drop_out_keep_rate = None
        self.fixed_point_encoder = None if params.floating_point_precision is None \
            else FixedPointEncoder(2 ** params.floating_point_precision)
        self.mask_table = None

    """
    Init
    """

    def set_transfer_variable(self, transfer_variable):
        self.transfer_variable = transfer_variable

    def set_partition(self, partition):
        self.partitions = partition

    def set_backward_select_strategy(self):
        self.do_backward_select_strategy = True

    """
    Forward & Backward
    """

    def plaintext_forward(self, host_input, epoch, batch, train):
        self.send_forward_to_guest(host_input, epoch, batch, train)

    def plaintext_backward(self, epoch, batch):
        return self.get_host_backward_from_guest(epoch, batch)
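
    # Host-side forward: encrypt the bottom output and send it to the guest,
    # then decrypt the guest's noise-masked dense output. Before returning it,
    # add host_input * acc_noise so that, once the guest removes its own
    # noise, the result matches the dense output under the true host weights
    # (the guest's weight copy is offset by the accumulated noise).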
    def forward(self, host_input, epoch=0, batch=0, train=True, **kwargs):
        if self.plaintext:
            self.plaintext_forward(host_input, epoch, batch, train)
            return
        if train and not self.drop_out_init:
            self.drop_out_init = True
            self.drop_out_keep_rate = self.transfer_variable.drop_out_info.get(
                0, role=consts.GUEST, suffix=('dropout_rate', ))
            if self.drop_out_keep_rate == -1:
                self.drop_out_keep_rate = None
        LOGGER.info(
            "forward propagation: encrypt host_bottom_output of epoch {} batch {}".format(
                epoch, batch))
        host_input = PaillierTensor(host_input, partitions=self.partitions)
        encrypted_host_input = host_input.encrypt(self.encrypter)
        self.send_forward_to_guest(
            encrypted_host_input.get_obj(), epoch, batch, train)
        encrypted_guest_forward = PaillierTensor(
            self.get_guest_encrypted_forward_from_guest(epoch, batch))
        decrypted_guest_forward = encrypted_guest_forward.decrypt(
            self.encrypter)
        if self.fixed_point_encoder:
            decrypted_guest_forward = decrypted_guest_forward.decode(
                self.fixed_point_encoder)
        if self.input_shape is None:
            self.input_shape = host_input.shape[1]
            self.output_unit = self.get_interactive_layer_output_unit()
        if self.acc_noise is None:
            self.acc_noise = np.zeros((self.input_shape, self.output_unit))
        mask_table = None
        if train and self.drop_out_keep_rate and self.drop_out_keep_rate < 1:
            mask_table = self.get_interactive_layer_drop_out_table(
                epoch, batch)
        if mask_table:
            decrypted_guest_forward_with_noise = decrypted_guest_forward + \
                (host_input * self.acc_noise).select_columns(mask_table)
            self.mask_table = mask_table
        else:
            noise_part = (host_input * self.acc_noise)
            decrypted_guest_forward_with_noise = decrypted_guest_forward + noise_part
        self.send_decrypted_guest_forward_with_noise_to_guest(
            decrypted_guest_forward_with_noise.get_obj(), epoch, batch)
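
    # Host-side backward: decrypt the guest's blinded weight gradient, then
    # add fresh random noise scaled by 1 / learning_rate before returning it.
    # The guest's "weight -= lr * gradient" step thus shifts its weight copy
    # by exactly that noise, which the host adds to acc_noise to keep the
    # true-weight invariant for the next forward pass.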
    def backward(self, epoch, batch):
        if self.plaintext:
            return self.plaintext_backward(epoch, batch), []
        do_backward = True
        selective_ids = []
        if self.do_backward_select_strategy:
            selective_ids, do_backward = self.send_backward_select_info(
                epoch, batch)
            if not do_backward:
                return [], selective_ids
        encrypted_guest_weight_gradient = self.get_guest_encrypted_weight_gradient_from_guest(
            epoch, batch)
        LOGGER.info(
            "decrypt weight gradient of epoch {} batch {}".format(
                epoch, batch))
        decrypted_guest_weight_gradient = self.encrypter.recursive_decrypt(
            encrypted_guest_weight_gradient)
        noise_weight_gradient = self.rng_generator.generate_random_number(
            (self.input_shape, self.output_unit))
        decrypted_guest_weight_gradient += noise_weight_gradient / self.learning_rate
        self.send_guest_decrypted_weight_gradient_to_guest(
            decrypted_guest_weight_gradient, epoch, batch)
        LOGGER.info(
            "encrypt acc_noise of epoch {} batch {}".format(
                epoch, batch))
        encrypted_acc_noise = self.encrypter.recursive_encrypt(self.acc_noise)
        self.send_encrypted_acc_noise_to_guest(
            encrypted_acc_noise, epoch, batch)
        self.acc_noise += noise_weight_gradient
        host_input_gradient = PaillierTensor(
            self.get_host_backward_from_guest(epoch, batch))
        host_input_gradient = host_input_gradient.decrypt(self.encrypter)
        if self.fixed_point_encoder:
            host_input_gradient = host_input_gradient.decode(
                self.fixed_point_encoder).numpy()
        else:
            host_input_gradient = host_input_gradient.numpy()
        return host_input_gradient, selective_ids
  714. """
  715. Communication Function
  716. """
  717. def send_backward_select_info(self, epoch, batch):
  718. selective_ids, do_backward = self.transfer_variable.selective_info.get(
  719. idx=0, suffix=(epoch, batch,))
  720. return selective_ids, do_backward
  721. def send_encrypted_acc_noise_to_guest(
  722. self, encrypted_acc_noise, epoch, batch):
  723. self.transfer_variable.encrypted_acc_noise.remote(encrypted_acc_noise,
  724. idx=0,
  725. role=consts.GUEST,
  726. suffix=(epoch, batch,))
  727. def get_interactive_layer_output_unit(self):
  728. return self.transfer_variable.interactive_layer_output_unit.get(idx=0)
  729. def get_guest_encrypted_weight_gradient_from_guest(self, epoch, batch):
  730. encrypted_guest_weight_gradient = self.transfer_variable.encrypted_guest_weight_gradient.get(
  731. idx=0, suffix=(epoch, batch,))
  732. return encrypted_guest_weight_gradient
  733. def get_interactive_layer_drop_out_table(self, epoch, batch):
  734. return self.transfer_variable.drop_out_table.get(
  735. idx=0, suffix=(epoch, batch,))
  736. def send_forward_to_guest(self, encrypted_host_input, epoch, batch, train):
  737. self.transfer_variable.encrypted_host_forward.remote(
  738. encrypted_host_input, idx=0, role=consts.GUEST, suffix=(epoch, batch, train))
  739. def send_guest_decrypted_weight_gradient_to_guest(
  740. self, decrypted_guest_weight_gradient, epoch, batch):
  741. self.transfer_variable.decrypted_guest_weight_gradient.remote(
  742. decrypted_guest_weight_gradient, idx=0, role=consts.GUEST, suffix=(epoch, batch,))
  743. def get_host_backward_from_guest(self, epoch, batch):
  744. host_backward = self.transfer_variable.host_backward.get(
  745. idx=0, suffix=(epoch, batch,))
  746. return host_backward
  747. def get_guest_encrypted_forward_from_guest(self, epoch, batch):
  748. encrypted_guest_forward = self.transfer_variable.encrypted_guest_forward.get(
  749. idx=0, suffix=(epoch, batch,))
  750. return encrypted_guest_forward
  751. def send_decrypted_guest_forward_with_noise_to_guest(
  752. self, decrypted_guest_forward_with_noise, epoch, batch):
  753. self.transfer_variable.decrypted_guest_forward.remote(
  754. decrypted_guest_forward_with_noise,
  755. idx=0,
  756. role=consts.GUEST,
  757. suffix=(
  758. epoch,
  759. batch,
  760. ))
  761. """
  762. Encrypter
  763. """
  764. def generate_encrypter(self, param):
  765. LOGGER.info("generate encrypter")
  766. if param.encrypt_param.method.lower() == consts.PAILLIER.lower():
  767. encrypter = PaillierEncrypt()
  768. encrypter.generate_key(param.encrypt_param.key_length)
  769. else:
  770. raise NotImplementedError("encrypt method not supported yet!!!")
  771. return encrypter
  772. """
  773. Model IO
  774. """
  775. def export_model(self):
  776. interactive_layer_param = InteractiveLayerParam()
  777. interactive_layer_param.acc_noise = pickle.dumps(self.acc_noise)
  778. return interactive_layer_param
  779. def restore_model(self, interactive_layer_param):
  780. self.acc_noise = pickle.loads(interactive_layer_param.acc_noise)