1234567891011121314 |
- from mas import models
- from easyfl.tracking.evaluation import count_model_params
def get_model(arch, tasks, pretrained=False):
    """Build the model named ``arch`` from ``mas.models`` and report its size.

    Looks up the constructor ``arch`` in the ``models`` module namespace and
    calls it with ``pretrained`` and ``tasks``. Then prints parameter counts
    for the whole model, its encoder(s), and each per-task decoder.

    Args:
        arch: Name of a constructor defined in ``mas.models``.
        tasks: Passed through unchanged to the constructor as ``tasks=``.
        pretrained: Passed through unchanged to the constructor as
            ``pretrained=``.

    Returns:
        The constructed model instance.

    Raises:
        KeyError: If ``arch`` is not a name defined in ``mas.models``.
    """
    try:
        model_fn = models.__dict__[arch]
    except KeyError:
        # Keep the KeyError type for existing callers, but say what was missing.
        raise KeyError(f"Unknown model architecture: {arch!r}") from None

    model = model_fn(pretrained=pretrained, tasks=tasks)
    print(f"Model has {count_model_params(model)} parameters")

    # Models expose either a single shared `encoder` or a list `encoders`.
    # Only a missing attribute should trigger the fallback. A failure inside
    # count_model_params must still surface, so the try body is kept minimal.
    try:
        encoder = model.encoder
    except AttributeError:
        # NOTE(review): reports only encoders[0] as "each"; this assumes all
        # encoders have the same size — confirm for the relevant architectures.
        print(f"Each encoder has {count_model_params(model.encoders[0])} parameters")
    else:
        print(f"Encoder has {count_model_params(encoder)} parameters")

    # Presumably a task-name -> decoder mapping; only the values are used here.
    for decoder in model.task_to_decoder.values():
        print(f"Decoder has {count_model_params(decoder)} parameters")
    return model
|