common.py 946 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344
import io
import random
import tempfile

import numpy as np
import torch as t
  4. ML_PATH = 'federatedml.nn'
  5. HOMOMODELMETA = "HomoNNMeta"
  6. HOMOMODELPARAM = "HomoNNParam"
  7. def global_seed(seed):
  8. # set random seed of torch
  9. t.manual_seed(seed)
  10. t.cuda.manual_seed_all(seed)
  11. t.backends.cudnn.deterministic = True
  12. def get_homo_model_dict(param, meta):
  13. return {HOMOMODELPARAM: param, # param
  14. HOMOMODELMETA: meta} # meta
  15. def get_homo_param_meta(model_dict):
  16. return model_dict.get(HOMOMODELPARAM), model_dict.get(HOMOMODELMETA)
  17. # read model from model bytes
  18. def recover_model_bytes(model_bytes):
  19. with tempfile.TemporaryFile() as f:
  20. f.write(model_bytes)
  21. f.seek(0)
  22. model_dict = t.load(f)
  23. return model_dict
  24. def get_torch_model_bytes(model_dict):
  25. with tempfile.TemporaryFile() as f:
  26. t.save(model_dict, f)
  27. f.seek(0)
  28. model_saved_bytes = f.read()
  29. return model_saved_bytes