model.py 585 B

1234567891011121314
  1. from mas import models
  2. from easyfl.tracking.evaluation import count_model_params
  3. def get_model(arch, tasks, pretrained=False):
  4. model = models.__dict__[arch](pretrained=pretrained, tasks=tasks)
  5. print(f"Model has {count_model_params(model)} parameters")
  6. try:
  7. print(f"Encoder has {count_model_params(model.encoder)} parameters")
  8. except:
  9. print(f"Each encoder has {count_model_params(model.encoders[0])} parameters")
  10. for decoder in model.task_to_decoder.values():
  11. print(f"Decoder has {count_model_params(decoder)} parameters")
  12. return model