model.py 666 B

123456789101112131415161718192021222324
  1. import importlib
  2. import logging
  3. from os import path
  4. from torch import nn
  # Module-level logger, named after this module via ``__name__``.
  logger = logging.getLogger(name=__name__)
  class BaseModel(nn.Module):
      """Thin subclass of ``torch.nn.Module``.

      Its constructor does nothing beyond running ``nn.Module``'s own
      initialisation; it adds no parameters, buffers or submodules.
      """

      def __init__(self) -> None:
          # Zero-argument super() performs the same MRO lookup as the explicit
          # super(BaseModel, self) form under Python 3.
          super().__init__()
  def load_model(model_name: str):
      """Import ``easyfl.models.<model_name>`` and return its ``Model`` attribute.

      Args:
          model_name: Name of a module inside the ``easyfl.models`` package
              (presumably the stem of a ``<model_name>.py`` file in that
              package's directory — confirm against callers).

      Returns:
          The module's ``Model`` attribute, returned as-is (it is not
          instantiated here).

      Raises:
          ModuleNotFoundError: If ``easyfl.models.<model_name>`` or something
              it imports cannot be found. When the model module itself is the
              missing one, an error naming it is logged before re-raising.
          AttributeError: If the imported module does not define ``Model``.
      """
      model_path = "easyfl.models.{}".format(model_name)
      # Attempt the import and diagnose the failure, rather than pre-checking
      # for a "<model_name>.py" file: that check only logged and then fell
      # through to the same failing import, and it wrongly flagged models that
      # are importable without a single .py file (e.g. subpackages).
      try:
          model_lib = importlib.import_module(model_path)
      except ModuleNotFoundError as err:
          # Blame the model name only when the model module itself is missing;
          # a missing dependency imported *by* the model surfaces unchanged.
          if err.name == model_path:
              logger.error(
                  "Please specify a valid model: module %r not found (model name %r).",
                  model_path,
                  model_name,
              )
          raise
      try:
          model = getattr(model_lib, "Model")
      except AttributeError as err:
          # Same exception type as before, but with a message naming the module.
          raise AttributeError(
              "Model module {!r} does not define a 'Model' attribute".format(model_path)
          ) from err
      # TODO: maybe return the model class initiator
      return model