key_manager.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from Crypto.PublicKey import RSA
  2. from Crypto import Random
  3. from fate_flow.db.db_models import DB, SiteKeyInfo
  4. from fate_flow.entity.types import SiteKeyName
  5. from fate_flow.settings import SITE_AUTHENTICATION, PARTY_ID
  6. def rsa_key_generate():
  7. random_generator = Random.new().read
  8. rsa = RSA.generate(2048, random_generator)
  9. private_pem = rsa.exportKey().decode()
  10. public_pem = rsa.publickey().exportKey().decode()
  11. return private_pem, public_pem
  12. class RsaKeyManager:
  13. @classmethod
  14. def init(cls):
  15. if PARTY_ID and SITE_AUTHENTICATION:
  16. if not cls.get_key(PARTY_ID, key_name=SiteKeyName.PRIVATE.value):
  17. cls.generate_key(PARTY_ID)
  18. @classmethod
  19. @DB.connection_context()
  20. def create_or_update(cls, party_id, key, key_name=SiteKeyName.PUBLIC.value):
  21. defaults = {
  22. "f_party_id": party_id,
  23. "f_key_name": key_name,
  24. "f_key": key
  25. }
  26. entity_model, status = SiteKeyInfo.get_or_create(
  27. f_party_id=party_id,
  28. f_key_name=key_name,
  29. defaults=defaults
  30. )
  31. if status is False:
  32. for key in defaults:
  33. setattr(entity_model, key, defaults[key])
  34. entity_model.save(force_insert=False)
  35. return "update success"
  36. else:
  37. return "save success"
  38. @classmethod
  39. def generate_key(cls, party_id):
  40. private_key, public_key = rsa_key_generate()
  41. cls.create_or_update(party_id, private_key, key_name=SiteKeyName.PRIVATE.value)
  42. cls.create_or_update(party_id, public_key, key_name=SiteKeyName.PUBLIC.value)
  43. @classmethod
  44. @DB.connection_context()
  45. def get_key(cls, party_id, key_name=SiteKeyName.PUBLIC.value):
  46. site_info = SiteKeyInfo.query(party_id=party_id, key_name=key_name)
  47. if site_info:
  48. return site_info[0].f_key
  49. else:
  50. return None
  51. @classmethod
  52. @DB.connection_context()
  53. def delete(cls, party_id, key_name=SiteKeyName.PUBLIC.value):
  54. site_info = SiteKeyInfo.query(party_id=party_id, key_name=key_name)
  55. if site_info:
  56. return site_info[0].delete_instance()
  57. else:
  58. return None