#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
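
"""HE-protected interactive layer of Hetero NN.

The guest merges its own bottom-model output with the hosts'
Paillier-encrypted bottom-model outputs in a dense interactive layer.
Random blinding noise, plus a host-side accumulated-noise term
(acc_noise), lets both parties update their shares of the layer without
exposing plaintext activations or gradients.
"""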

import pickle

import numpy as np
import torch
from torch import autograd

from fate_arch.session import computing_session as session
from federatedml.nn.backend.torch.interactive import InteractiveLayer
from federatedml.nn.backend.torch.serialization import recover_sequential_from_dict
from federatedml.nn.backend.utils.rng import RandomNumberGenerator
from federatedml.nn.hetero.interactive.base import InteractiveLayerGuest, InteractiveLayerHost
from federatedml.nn.hetero.interactive.utils.numpy_layer import NumpyDenseLayerGuest, NumpyDenseLayerHost
from federatedml.nn.hetero.nn_component.torch_model import TorchNNModel, backward_loss
from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import InteractiveLayerParam
from federatedml.secureprotol import PaillierEncrypt
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.transfer_variable.base_transfer_variable import BaseTransferVariables
from federatedml.util import consts, LOGGER
from federatedml.util.fixpoint_solver import FixedPointEncoder
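
# When PLAINTEXT is True the layer skips encryption entirely and uses the
# plaintext_forward / plaintext_backward paths below, which exist for testing.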
PLAINTEXT = False


class HEInteractiveTransferVariable(BaseTransferVariables):
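    """Transfer variables exchanged between guest and host by the HE
    interactive layer; src/dst name the sending and receiving roles."""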

    def __init__(self, flowid=0):
        super().__init__(flowid)
        self.decrypted_guest_forward = self._create_variable(
            name='decrypted_guest_forward', src=['host'], dst=['guest'])
        self.decrypted_guest_weight_gradient = self._create_variable(
            name='decrypted_guest_weight_gradient', src=['host'], dst=['guest'])
        self.encrypted_acc_noise = self._create_variable(
            name='encrypted_acc_noise', src=['host'], dst=['guest'])
        self.encrypted_guest_forward = self._create_variable(
            name='encrypted_guest_forward', src=['guest'], dst=['host'])
        self.encrypted_guest_weight_gradient = self._create_variable(
            name='encrypted_guest_weight_gradient', src=['guest'], dst=['host'])
        self.encrypted_host_forward = self._create_variable(
            name='encrypted_host_forward', src=['host'], dst=['guest'])
        self.host_backward = self._create_variable(
            name='host_backward', src=['guest'], dst=['host'])
        self.selective_info = self._create_variable(
            name="selective_info", src=["guest"], dst=["host"])
        self.drop_out_info = self._create_variable(
            name="drop_out_info", src=["guest"], dst=["host"])
        self.drop_out_table = self._create_variable(
            name="drop_out_table", src=["guest"], dst=["host"])
        self.interactive_layer_output_unit = self._create_variable(
            name="interactive_layer_output_unit", src=["guest"], dst=["host"])


class DropOut(object):
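    """Inverted dropout applied to the interactive layer's output.

    A boolean mask is sampled once per batch and kept units are scaled by
    1 / keep_rate.  The mask can also be parallelized into a table and
    shared with the hosts, so both sides drop the same columns of the
    encrypted forward output.
    """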

    def __init__(self, rate, noise_shape):
        self._keep_rate = rate
        self._noise_shape = noise_shape
        self._batch_size = noise_shape[0]
        self._mask = None
        self._partition = None
        self._mask_table = None
        self._select_mask_table = None
        self._do_backward_select = False
        self._mask_table_cache = {}

    def forward(self, X):
        if X.shape == self._mask.shape:
            forward_x = X * self._mask / self._keep_rate
        else:
            forward_x = X * self._mask[0: len(X)] / self._keep_rate
        return forward_x

    def backward(self, grad):
        if self._do_backward_select:
            self._mask = self._select_mask_table[0: grad.shape[0]]
            self._select_mask_table = self._select_mask_table[grad.shape[0]:]
            return grad * self._mask / self._keep_rate
        else:
            if grad.shape == self._mask.shape:
                return grad * self._mask / self._keep_rate
            else:
                return grad * self._mask[0: grad.shape[0]] / self._keep_rate

    def generate_mask(self):
        self._mask = np.random.uniform(
            low=0, high=1, size=self._noise_shape) < self._keep_rate

    def generate_mask_table(self, shape):
        # generate the mask table according to the sample shape, because in
        # some batches sample_num < batch_size
        if shape == self._noise_shape:
            _mask_table = session.parallelize(
                self._mask, include_key=False, partition=self._partition)
        else:
            _mask_table = session.parallelize(
                self._mask[0: shape[0]], include_key=False, partition=self._partition)
        return _mask_table

    def set_partition(self, partition):
        self._partition = partition

    def select_backward_sample(self, select_ids):
        select_mask_table = self._mask[np.array(select_ids)]
        if self._select_mask_table is not None:
            self._select_mask_table = np.vstack(
                (self._select_mask_table, select_mask_table))
        else:
            self._select_mask_table = select_mask_table

    def do_backward_select_strategy(self):
        self._do_backward_select = True


class HEInteractiveLayerGuest(InteractiveLayerGuest):
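    """Guest side of the HE interactive layer.

    The guest owns the torch interactive model, but the encrypted linear
    algebra is done by numpy mirrors of its layers (NumpyDenseLayerGuest /
    NumpyDenseLayerHost); the torch model itself is only used for the
    activation and for model import/export.
    """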

    def __init__(self, params=None, layer_config=None, host_num=1):
        super(HEInteractiveLayerGuest, self).__init__(params)

        # transfer var
        self.host_num = host_num
        self.layer_config = layer_config
        self.transfer_variable = HEInteractiveTransferVariable()
        self.plaintext = PLAINTEXT
        self.host_input_shapes = []
        self.rng_generator = RandomNumberGenerator()
        self.learning_rate = params.interactive_layer_lr

        # cached tensor
        self.guest_tensor = None
        self.host_tensors = None
        self.dense_output_data_require_grad = None
        self.activation_out_require_grad = None

        # model
        self.model: InteractiveLayer = None
        self.guest_model = None
        self.host_model_list = []
        self.batch_size = None
        self.partitions = 0
        self.do_backward_select_strategy = False
        self.optimizer = None

        # drop out
        self.drop_out_initiated = False
        self.drop_out = None
        self.drop_out_keep_rate = None

        self.fixed_point_encoder = None if params.floating_point_precision is None else FixedPointEncoder(
            2 ** params.floating_point_precision)
        self.send_output_unit = False

        # float64
        self.float64 = False

    """
    Init functions
    """

    def set_flow_id(self, flow_id):
        self.transfer_variable.set_flowid(flow_id)

    def set_backward_select_strategy(self):
        self.do_backward_select_strategy = True

    def set_batch(self, batch_size):
        self.batch_size = batch_size

    def set_partition(self, partition):
        self.partitions = partition

    def _build_model(self):
        if self.model is None:
            raise ValueError('torch interactive model is not initialized!')

        for i in range(self.host_num):
            host_model = NumpyDenseLayerHost()
            host_model.build(self.model.host_model[i])
            host_model.set_learning_rate(self.learning_rate)
            self.host_model_list.append(host_model)

        self.guest_model = NumpyDenseLayerGuest()
        self.guest_model.build(self.model.guest_model)
        self.guest_model.set_learning_rate(self.learning_rate)

        if self.do_backward_select_strategy:
            self.guest_model.set_backward_selective_strategy()
            self.guest_model.set_batch(self.batch_size)
            for host_model in self.host_model_list:
                host_model.set_backward_selective_strategy()
                host_model.set_batch(self.batch_size)
- """
- Drop out functions
- """
- def init_drop_out_parameter(self):
- if isinstance(self.model.param_dict['dropout'], float):
- self.drop_out_keep_rate = 1 - self.model.param_dict['dropout']
- else:
- self.drop_out_keep_rate = -1
- self.transfer_variable.drop_out_info.remote(
- self.drop_out_keep_rate, idx=-1, suffix=('dropout_rate', ))
- self.drop_out_initiated = True
- def _create_drop_out(self, shape):
- if self.drop_out_keep_rate and self.drop_out_keep_rate != 1 and self.drop_out_keep_rate > 0:
- if not self.drop_out:
- self.drop_out = DropOut(
- noise_shape=shape, rate=self.drop_out_keep_rate)
- self.drop_out.set_partition(self.partitions)
- if self.do_backward_select_strategy:
- self.drop_out.do_backward_select_strategy()
- self.drop_out.generate_mask()

    @staticmethod
    def expand_columns(tensor, keep_array):
        shape = keep_array.shape
        tensor = np.reshape(tensor, (tensor.size,))
        keep = np.reshape(keep_array, (keep_array.size,))
        ret_tensor = []
        idx = 0
        for x in keep:
            if x == 0:
                ret_tensor.append(0)
            else:
                ret_tensor.append(tensor[idx])
                idx += 1
        return np.reshape(np.array(ret_tensor), shape)
- """
- Plaintext forward/backward, these interfaces are for testing
- """
- def plaintext_forward(self, guest_input, epoch=0, batch=0, train=True):
- if self.model is None:
- self.model = recover_sequential_from_dict(self.layer_config)[0]
- if self.float64:
- self.model.type(torch.float64)
- if self.optimizer is None:
- self.optimizer = torch.optim.SGD(
- params=self.model.parameters(), lr=self.learning_rate)
- if train:
- self.model.train()
- else:
- self.model.eval()
- with torch.no_grad():
- guest_tensor = torch.from_numpy(guest_input)
- host_inputs = self.get_forward_from_host(
- epoch, batch, train, idx=-1)
- host_tensors = [torch.from_numpy(arr) for arr in host_inputs]
- interactive_out = self.model(guest_tensor, host_tensors)
- self.guest_tensor = guest_tensor
- self.host_tensors = host_tensors
- return interactive_out.cpu().detach().numpy()
- def plaintext_backward(self, output_gradient, epoch, batch):
- # compute input gradient
- self.guest_tensor: torch.Tensor = self.guest_tensor.requires_grad_(True)
- for tensor in self.host_tensors:
- tensor.requires_grad_(True)
- out = self.model(self.guest_tensor, self.host_tensors)
- loss = backward_loss(out, torch.from_numpy(output_gradient))
- backward_list = [self.guest_tensor]
- backward_list.extend(self.host_tensors)
- ret_grad = autograd.grad(loss, backward_list)
- # update model
- self.guest_tensor: torch.Tensor = self.guest_tensor.requires_grad_(False)
- for tensor in self.host_tensors:
- tensor.requires_grad_(False)
- self.optimizer.zero_grad()
- out = self.model(self.guest_tensor, self.host_tensors)
- loss = backward_loss(out, torch.from_numpy(output_gradient))
- loss.backward()
- self.optimizer.step()
- self.guest_tensor, self.host_tensors = None, None
- for idx, host_grad in enumerate(ret_grad[1:]):
- self.send_host_backward_to_host(host_grad, epoch, batch, idx=idx)
- return ret_grad[0]
- """
- Activation forward & backward
- """
- def activation_forward(self, dense_out, with_grad=True):
- if with_grad:
- if (self.dense_output_data_require_grad is not None) or (
- self.activation_out_require_grad is not None):
- raise ValueError(
- 'torch forward error, related required grad tensors are not freed')
- self.dense_output_data_require_grad = dense_out.requires_grad_(
- True)
- activation_out_ = self.model.activation(
- self.dense_output_data_require_grad)
- self.activation_out_require_grad = activation_out_
- else:
- with torch.no_grad():
- activation_out_ = self.model.activation(dense_out)
- return activation_out_.cpu().detach().numpy()
- def activation_backward(self, output_gradients):
- if self.activation_out_require_grad is None and self.dense_output_data_require_grad is None:
- raise ValueError('related grad is None, cannot compute backward')
- loss = backward_loss(
- self.activation_out_require_grad,
- torch.Tensor(output_gradients))
- activation_backward_grad = torch.autograd.grad(
- loss, self.dense_output_data_require_grad)
- self.activation_out_require_grad = None
- self.dense_output_data_require_grad = None
- return activation_backward_grad[0].cpu().detach().numpy()
- """
- Forward & Backward
- """
- def print_log(self, descr, epoch, batch, train):
- if train:
- LOGGER.info("{} epoch {} batch {}"
- "".format(descr, epoch, batch))
- else:
- LOGGER.info("predicting, {} pred iteration {} batch {}"
- "".format(descr, epoch, batch))

    def forward_interactive(
            self,
            encrypted_host_input,
            epoch,
            batch,
            train=True):
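        """Compute the hosts' share of the dense output under encryption.

        For each host: run the encrypted bottom output through that host's
        numpy linear layer, drop masked columns when dropout is active, add
        random blinding noise, and send the result to the host for
        decryption.  The decrypted replies are un-blinded, re-expanded to
        full width via expand_columns, and summed into one PaillierTensor.
        """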
        self.print_log(
            'get encrypted dense output of host model of',
            epoch,
            batch,
            train)

        mask_table_list = []
        guest_noises = []
        host_idx = 0

        for model, host_bottom_input in zip(
                self.host_model_list, encrypted_host_input):

            encrypted_fw = model(host_bottom_input, self.fixed_point_encoder)

            mask_table = None
            if train:
                self._create_drop_out(encrypted_fw.shape)
                if self.drop_out:
                    mask_table = self.drop_out.generate_mask_table(
                        encrypted_fw.shape)
                if mask_table:
                    encrypted_fw = encrypted_fw.select_columns(mask_table)
                    mask_table_list.append(mask_table)

            guest_forward_noise = self.rng_generator.fast_generate_random_number(
                encrypted_fw.shape, encrypted_fw.partitions, keep_table=mask_table)

            if self.fixed_point_encoder:
                encrypted_fw += guest_forward_noise.encode(
                    self.fixed_point_encoder)
            else:
                encrypted_fw += guest_forward_noise
            guest_noises.append(guest_forward_noise)

            self.send_guest_encrypted_forward_output_with_noise_to_host(
                encrypted_fw.get_obj(), epoch, batch, idx=host_idx)
            if mask_table:
                self.send_interactive_layer_drop_out_table(
                    mask_table, epoch, batch, idx=host_idx)
            host_idx += 1

        # get decrypted forward list from hosts
        decrypted_dense_outputs = self.get_guest_decrypted_forward_from_host(
            epoch, batch, idx=-1)
        merge_output = None
        for idx, (outputs, noise) in enumerate(
                zip(decrypted_dense_outputs, guest_noises)):
            out = PaillierTensor(outputs) - noise
            if len(mask_table_list) != 0:
                out = PaillierTensor(
                    out.get_obj().join(
                        mask_table_list[idx],
                        self.expand_columns))
            if merge_output is None:
                merge_output = out
            else:
                merge_output = merge_output + out

        return merge_output

    def forward(self, x, epoch: int, batch: int, train: bool = True, **kwargs):
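        """Run one forward pass of the interactive layer.

        Gathers the encrypted host bottom outputs, adds the guest's own
        linear output, then applies the activation (and dropout when
        training) to produce the layer's numpy output.
        """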
        self.print_log(
            'interactive layer running forward propagation',
            epoch,
            batch,
            train)

        if self.plaintext:
            return self.plaintext_forward(x, epoch, batch, train)

        if self.model is None:
            self.model = recover_sequential_from_dict(self.layer_config)[0]
            LOGGER.debug('interactive model is {}'.format(self.model))
            # for multi host cases
            LOGGER.debug(
                'host num is {}, len host model {}'.format(
                    self.host_num, len(self.model.host_model)))
            assert self.host_num == len(self.model.host_model), \
                'host number is {}, but host linear layer number is {}, ' \
                'please check your interactive configuration, make sure ' \
                'that host layer number equals to host number'.format(
                    self.host_num, len(self.model.host_model))
            if self.float64:
                self.model.type(torch.float64)

        if train and not self.drop_out_initiated:
            self.init_drop_out_parameter()

        host_inputs = self.get_forward_from_host(epoch, batch, train, idx=-1)
        host_bottom_inputs_tensor = []
        host_input_shapes = []
        for i in host_inputs:
            pt = PaillierTensor(i)
            host_bottom_inputs_tensor.append(pt)
            host_input_shapes.append(pt.shape[1])

        self.model.lazy_to_linear(x.shape[1], host_dims=host_input_shapes)
        self.host_input_shapes = host_input_shapes

        if self.guest_model is None:
            LOGGER.info("building interactive layers' training model")
            self._build_model()

        if not self.partitions:
            self.partitions = host_bottom_inputs_tensor[0].partitions

        if not self.send_output_unit:
            self.send_output_unit = True
            for idx in range(self.host_num):
                self.send_interactive_layer_output_unit(
                    self.host_model_list[idx].output_shape[0], idx=idx)

        guest_output = self.guest_model(x)
        host_output = self.forward_interactive(
            host_bottom_inputs_tensor, epoch, batch, train)

        if guest_output is not None:
            dense_output_data = host_output + \
                PaillierTensor(guest_output, partitions=self.partitions)
        else:
            dense_output_data = host_output

        self.print_log(
            "start to get interactive layer's activation output of",
            epoch,
            batch,
            train)

        if self.float64:  # result after encrypted calculation is float64
            dense_out = torch.from_numpy(dense_output_data.numpy())
        else:
            dense_out = torch.Tensor(
                dense_output_data.numpy())  # convert to float32

        if self.do_backward_select_strategy:
            for h in self.host_model_list:
                h.activation_input = dense_out.cpu().detach().numpy()

        # unless the backward-select strategy is used, grad can be computed
        # directly during training
        with_grad = train and not self.do_backward_select_strategy

        activation_out = self.activation_forward(
            dense_out, with_grad=with_grad)

        if train and self.drop_out:
            return self.drop_out.forward(activation_out)

        return activation_out

    def backward_interactive(
            self,
            host_model,
            activation_gradient,
            epoch,
            batch,
            host_idx):
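        """Exchange one host's weight gradient under encryption.

        The guest blinds the encrypted weight gradient with noise_w before
        sending it for decryption, then removes noise_w from the reply; the
        host's encrypted accumulated noise is fetched alongside so the input
        gradient can be corrected in update_host.  Note the reply still
        carries a fresh host-side noise term (see the host's backward).
        """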
        LOGGER.info(
            "get encrypted weight gradient of epoch {} batch {}".format(
                epoch, batch))
        encrypted_weight_gradient = host_model.get_weight_gradient(
            activation_gradient, encoder=self.fixed_point_encoder)
        if self.fixed_point_encoder:
            encrypted_weight_gradient = self.fixed_point_encoder.decode(
                encrypted_weight_gradient)
        noise_w = self.rng_generator.generate_random_number(
            encrypted_weight_gradient.shape)
        self.transfer_variable.encrypted_guest_weight_gradient.remote(
            encrypted_weight_gradient + noise_w,
            role=consts.HOST,
            idx=host_idx,
            suffix=(epoch, batch,))

        LOGGER.info(
            "get decrypted weight gradient of epoch {} batch {}".format(
                epoch, batch))
        decrypted_weight_gradient = self.transfer_variable.decrypted_guest_weight_gradient.get(
            idx=host_idx, suffix=(epoch, batch,))
        decrypted_weight_gradient -= noise_w

        encrypted_acc_noise = self.get_encrypted_acc_noise_from_host(
            epoch, batch, idx=host_idx)

        return decrypted_weight_gradient, encrypted_acc_noise

    def backward(self, error, epoch: int, batch: int, selective_ids=None):
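        """Run one backward pass: update the guest and host linear layers
        and send each host the gradient w.r.t. its bottom-model output."""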
        if self.plaintext:
            return self.plaintext_backward(error, epoch, batch)

        if selective_ids:
            for host_model in self.host_model_list:
                host_model.select_backward_sample(selective_ids)
            self.guest_model.select_backward_sample(selective_ids)
            if self.drop_out:
                self.drop_out.select_backward_sample(selective_ids)

        if self.do_backward_select_strategy:
            # send to all hosts
            self.send_backward_select_info(
                selective_ids, len(error), epoch, batch, -1)

        if len(error) > 0:
            LOGGER.debug(
                "interactive layer start backward propagation of epoch {} batch {}".format(
                    epoch, batch))
            if not self.do_backward_select_strategy:
                activation_gradient = self.activation_backward(error)
            else:
                act_input = self.host_model_list[0].get_selective_activation_input()
                _ = self.activation_forward(torch.from_numpy(act_input), with_grad=True)
                activation_gradient = self.activation_backward(error)

            if self.drop_out:
                activation_gradient = self.drop_out.backward(
                    activation_gradient)

            LOGGER.debug(
                "interactive layer update guest weight of epoch {} batch {}".format(
                    epoch, batch))

            # update guest model
            guest_input_gradient = self.update_guest(activation_gradient)

            LOGGER.debug('update host model weights')
            for idx, host_model in enumerate(self.host_model_list):
                # update host models
                host_weight_gradient, acc_noise = self.backward_interactive(
                    host_model, activation_gradient, epoch, batch, host_idx=idx)
                host_input_gradient = self.update_host(
                    host_model, activation_gradient, host_weight_gradient, acc_noise)
                self.send_host_backward_to_host(
                    host_input_gradient.get_obj(), epoch, batch, idx=idx)

            return guest_input_gradient
        else:
            return []
- """
- Model update
- """
- def update_guest(self, activation_gradient):
- input_gradient = self.guest_model.get_input_gradient(
- activation_gradient)
- weight_gradient = self.guest_model.get_weight_gradient(
- activation_gradient)
- self.guest_model.update_weight(weight_gradient)
- self.guest_model.update_bias(activation_gradient)
- return input_gradient
- def update_host(
- self,
- host_model,
- activation_gradient,
- weight_gradient,
- acc_noise):
- activation_gradient_tensor = PaillierTensor(
- activation_gradient, partitions=self.partitions)
- input_gradient = host_model.get_input_gradient(
- activation_gradient_tensor, acc_noise, encoder=self.fixed_point_encoder)
- host_model.update_weight(weight_gradient)
- host_model.update_bias(activation_gradient)
- return input_gradient
- """
- Communication functions
- """
- def send_interactive_layer_output_unit(self, shape, idx=0):
- self.transfer_variable.interactive_layer_output_unit.remote(
- shape, role=consts.HOST, idx=idx)
- def send_backward_select_info(
- self,
- selective_ids,
- gradient_len,
- epoch,
- batch,
- idx):
- self.transfer_variable.selective_info.remote(
- (selective_ids, gradient_len), role=consts.HOST, idx=idx, suffix=(
- epoch, batch,))
- def send_host_backward_to_host(self, host_error, epoch, batch, idx):
- self.transfer_variable.host_backward.remote(host_error,
- role=consts.HOST,
- idx=idx,
- suffix=(epoch, batch,))
- def get_forward_from_host(self, epoch, batch, train, idx=0):
- return self.transfer_variable.encrypted_host_forward.get(
- idx=idx, suffix=(epoch, batch, train))
- def send_guest_encrypted_forward_output_with_noise_to_host(
- self, encrypted_guest_forward_with_noise, epoch, batch, idx):
- return self.transfer_variable.encrypted_guest_forward.remote(
- encrypted_guest_forward_with_noise,
- role=consts.HOST,
- idx=idx,
- suffix=(
- epoch,
- batch,
- ))
- def send_interactive_layer_drop_out_table(
- self, mask_table, epoch, batch, idx):
- return self.transfer_variable.drop_out_table.remote(
- mask_table, role=consts.HOST, idx=idx, suffix=(epoch, batch,))
- def get_guest_decrypted_forward_from_host(self, epoch, batch, idx=0):
- return self.transfer_variable.decrypted_guest_forward.get(
- idx=idx, suffix=(epoch, batch,))
- def get_encrypted_acc_noise_from_host(self, epoch, batch, idx=0):
- return self.transfer_variable.encrypted_acc_noise.get(
- idx=idx, suffix=(epoch, batch,))
- """
- Model IO
- """
- def transfer_np_model_to_torch_interactive_layer(self):
- self.model = self.model.cpu()
- if self.guest_model is not None:
- guest_weight = self.guest_model.get_weight()
- model: torch.nn.Linear = self.model.guest_model
- model.weight.data.copy_(torch.Tensor(guest_weight))
- if self.guest_model.bias is not None:
- model.bias.data.copy_(torch.Tensor(self.guest_model.bias))
- for host_np_model, torch_model in zip(
- self.host_model_list, self.model.host_model):
- host_weight = host_np_model.get_weight()
- torch_model.weight.data.copy_(torch.Tensor(host_weight))
- if host_np_model.bias is not None:
- torch_model.bias.data.copy_(torch.Tensor(torch_model.bias))
- def export_model(self):
- self.transfer_np_model_to_torch_interactive_layer()
- interactive_layer_param = InteractiveLayerParam()
- interactive_layer_param.interactive_guest_saved_model_bytes = TorchNNModel.get_model_bytes(
- self.model)
- interactive_layer_param.host_input_shape.extend(self.host_input_shapes)
- return interactive_layer_param
- def restore_model(self, interactive_layer_param):
- self.host_input_shapes = list(interactive_layer_param.host_input_shape)
- self.model = TorchNNModel.recover_model_bytes(
- interactive_layer_param.interactive_guest_saved_model_bytes)
- self._build_model()


class HEInteractiveLayerHost(InteractiveLayerHost):
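    """Host side of the HE interactive layer.

    The host holds the Paillier key pair: it encrypts its bottom-model
    output before sending it to the guest and decrypts the guest's blinded
    intermediates on request.  acc_noise records the noise that successive
    backward steps have folded into the guest-held copy of this host's
    weights.
    """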

    def __init__(self, params):
        super(HEInteractiveLayerHost, self).__init__(params)

        self.plaintext = PLAINTEXT
        self.acc_noise = None
        self.learning_rate = params.interactive_layer_lr
        self.encrypter = self.generate_encrypter(params)
        self.transfer_variable = HEInteractiveTransferVariable()
        self.partitions = 1
        self.input_shape = None
        self.output_unit = None
        self.rng_generator = RandomNumberGenerator()
        self.do_backward_select_strategy = False
        self.drop_out_init = False
        self.drop_out_keep_rate = None
        self.fixed_point_encoder = None if params.floating_point_precision is None else FixedPointEncoder(
            2 ** params.floating_point_precision)
        self.mask_table = None

    """
    Init
    """

    def set_transfer_variable(self, transfer_variable):
        self.transfer_variable = transfer_variable

    def set_partition(self, partition):
        self.partitions = partition

    def set_backward_select_strategy(self):
        self.do_backward_select_strategy = True
- """
- Forward & Backward
- """
- def plaintext_forward(self, host_input, epoch, batch, train):
- self.send_forward_to_guest(host_input, epoch, batch, train)
- def plaintext_backward(self, epoch, batch):
- return self.get_host_backward_from_guest(epoch, batch)
- def forward(self, host_input, epoch=0, batch=0, train=True, **kwargs):
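        """Encrypt and send the bottom output, then serve the guest's
        decryption request.

        The decrypted guest forward is returned with host_input * acc_noise
        added, compensating for the noise accumulated in the guest-held copy
        of this host's weights.
        """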
        if self.plaintext:
            self.plaintext_forward(host_input, epoch, batch, train)
            return

        if train and not self.drop_out_init:
            self.drop_out_init = True
            self.drop_out_keep_rate = self.transfer_variable.drop_out_info.get(
                0, role=consts.GUEST, suffix=('dropout_rate',))
            if self.drop_out_keep_rate == -1:
                self.drop_out_keep_rate = None

        LOGGER.info(
            "forward propagation: encrypt host_bottom_output of epoch {} batch {}".format(
                epoch, batch))

        host_input = PaillierTensor(host_input, partitions=self.partitions)
        encrypted_host_input = host_input.encrypt(self.encrypter)
        self.send_forward_to_guest(
            encrypted_host_input.get_obj(), epoch, batch, train)

        encrypted_guest_forward = PaillierTensor(
            self.get_guest_encrypted_forward_from_guest(epoch, batch))
        decrypted_guest_forward = encrypted_guest_forward.decrypt(
            self.encrypter)
        if self.fixed_point_encoder:
            decrypted_guest_forward = decrypted_guest_forward.decode(
                self.fixed_point_encoder)

        if self.input_shape is None:
            self.input_shape = host_input.shape[1]
            self.output_unit = self.get_interactive_layer_output_unit()

        if self.acc_noise is None:
            self.acc_noise = np.zeros((self.input_shape, self.output_unit))

        mask_table = None
        if train and self.drop_out_keep_rate and self.drop_out_keep_rate < 1:
            mask_table = self.get_interactive_layer_drop_out_table(epoch, batch)

        if mask_table:
            decrypted_guest_forward_with_noise = decrypted_guest_forward + \
                (host_input * self.acc_noise).select_columns(mask_table)
            self.mask_table = mask_table
        else:
            noise_part = (host_input * self.acc_noise)
            decrypted_guest_forward_with_noise = decrypted_guest_forward + noise_part

        self.send_decrypted_guest_forward_with_noise_to_guest(
            decrypted_guest_forward_with_noise.get_obj(), epoch, batch)

    def backward(self, epoch, batch):
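        """Serve the guest's gradient-decryption requests.

        Fresh noise (noise_weight_gradient), scaled by 1 / learning_rate, is
        added to the decrypted weight gradient; assuming the guest's numpy
        layer applies w -= lr * grad, this folds exactly
        -noise_weight_gradient into the guest-held weights, and acc_noise is
        incremented to keep track of the total drift.
        """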
        if self.plaintext:
            return self.plaintext_backward(epoch, batch), []

        do_backward = True
        selective_ids = []
        if self.do_backward_select_strategy:
            selective_ids, do_backward = self.send_backward_select_info(
                epoch, batch)

        if not do_backward:
            return [], selective_ids

        encrypted_guest_weight_gradient = self.get_guest_encrypted_weight_gradient_from_guest(
            epoch, batch)
        LOGGER.info(
            "decrypt weight gradient of epoch {} batch {}".format(
                epoch, batch))

        decrypted_guest_weight_gradient = self.encrypter.recursive_decrypt(
            encrypted_guest_weight_gradient)
        noise_weight_gradient = self.rng_generator.generate_random_number(
            (self.input_shape, self.output_unit))
        decrypted_guest_weight_gradient += noise_weight_gradient / self.learning_rate
        self.send_guest_decrypted_weight_gradient_to_guest(
            decrypted_guest_weight_gradient, epoch, batch)

        LOGGER.info(
            "encrypt acc_noise of epoch {} batch {}".format(
                epoch, batch))
        encrypted_acc_noise = self.encrypter.recursive_encrypt(self.acc_noise)
        self.send_encrypted_acc_noise_to_guest(
            encrypted_acc_noise, epoch, batch)
        self.acc_noise += noise_weight_gradient

        host_input_gradient = PaillierTensor(
            self.get_host_backward_from_guest(epoch, batch))
        host_input_gradient = host_input_gradient.decrypt(self.encrypter)
        if self.fixed_point_encoder:
            host_input_gradient = host_input_gradient.decode(
                self.fixed_point_encoder).numpy()
        else:
            host_input_gradient = host_input_gradient.numpy()

        return host_input_gradient, selective_ids
- """
- Communication Function
- """
- def send_backward_select_info(self, epoch, batch):
- selective_ids, do_backward = self.transfer_variable.selective_info.get(
- idx=0, suffix=(epoch, batch,))
- return selective_ids, do_backward
- def send_encrypted_acc_noise_to_guest(
- self, encrypted_acc_noise, epoch, batch):
- self.transfer_variable.encrypted_acc_noise.remote(encrypted_acc_noise,
- idx=0,
- role=consts.GUEST,
- suffix=(epoch, batch,))
- def get_interactive_layer_output_unit(self):
- return self.transfer_variable.interactive_layer_output_unit.get(idx=0)
- def get_guest_encrypted_weight_gradient_from_guest(self, epoch, batch):
- encrypted_guest_weight_gradient = self.transfer_variable.encrypted_guest_weight_gradient.get(
- idx=0, suffix=(epoch, batch,))
- return encrypted_guest_weight_gradient
- def get_interactive_layer_drop_out_table(self, epoch, batch):
- return self.transfer_variable.drop_out_table.get(
- idx=0, suffix=(epoch, batch,))
- def send_forward_to_guest(self, encrypted_host_input, epoch, batch, train):
- self.transfer_variable.encrypted_host_forward.remote(
- encrypted_host_input, idx=0, role=consts.GUEST, suffix=(epoch, batch, train))
- def send_guest_decrypted_weight_gradient_to_guest(
- self, decrypted_guest_weight_gradient, epoch, batch):
- self.transfer_variable.decrypted_guest_weight_gradient.remote(
- decrypted_guest_weight_gradient, idx=0, role=consts.GUEST, suffix=(epoch, batch,))
- def get_host_backward_from_guest(self, epoch, batch):
- host_backward = self.transfer_variable.host_backward.get(
- idx=0, suffix=(epoch, batch,))
- return host_backward
- def get_guest_encrypted_forward_from_guest(self, epoch, batch):
- encrypted_guest_forward = self.transfer_variable.encrypted_guest_forward.get(
- idx=0, suffix=(epoch, batch,))
- return encrypted_guest_forward
- def send_decrypted_guest_forward_with_noise_to_guest(
- self, decrypted_guest_forward_with_noise, epoch, batch):
- self.transfer_variable.decrypted_guest_forward.remote(
- decrypted_guest_forward_with_noise,
- idx=0,
- role=consts.GUEST,
- suffix=(
- epoch,
- batch,
- ))
- """
- Encrypter
- """
- def generate_encrypter(self, param):
- LOGGER.info("generate encrypter")
- if param.encrypt_param.method.lower() == consts.PAILLIER.lower():
- encrypter = PaillierEncrypt()
- encrypter.generate_key(param.encrypt_param.key_length)
- else:
- raise NotImplementedError("encrypt method not supported yet!!!")
- return encrypter
- """
- Model IO
- """
- def export_model(self):
- interactive_layer_param = InteractiveLayerParam()
- interactive_layer_param.acc_noise = pickle.dumps(self.acc_noise)
- return interactive_layer_param
- def restore_model(self, interactive_layer_param):
- self.acc_noise = pickle.loads(interactive_layer_param.acc_noise)
|