#!/usr/bin/env bash
# run.sh
#
# Launches easyfl/applications/mas/main.py through Slurm `srun`.
# Every training option can be overridden with a command-line flag;
# any option that is not supplied falls back to a built-in default.
  # --- Environment setup ---------------------------------------------------
  # Create the directory that receives the tee'd training log
  # (-p also creates the parent `log` directory).
  mkdir -p log/mas

  # Timestamp of this invocation, formatted YYYYmmdd_HHMMSS.
  # NOTE(review): not referenced anywhere else in the visible script.
  now=$(date +"%Y%m%d_%H%M%S")

  # Per-user locations on the shared filesystem.
  root_dir="/mnt/lustre/$(whoami)"
  project_dir="${root_dir}/easyfl/applications/mas"
  data_dir="${root_dir}/datasets/taskonomy_datasets"
  client_file="${project_dir}/clients.txt"

  # Append the current working directory to PYTHONPATH. The earlier form
  # used `${pwd}`, which is an (unset) variable rather than the working
  # directory, so nothing useful was appended. The `:+` expansion avoids a
  # leading ':' when PYTHONPATH starts out empty.
  export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"
  # --- Command-line parsing --------------------------------------------------
  # Parses `-flag value` pairs and stores each value in the matching global
  # variable (e.g. `-p gpu` sets partition=gpu). Every flag takes exactly one
  # value.
  #
  # Arguments: the script's positional parameters ("$@").
  # Globals:   writes partition, tasks, arch, local_epoch, clients_per_round,
  #            batch_size, rounds, lr, lr_type, test_every, save_model_every,
  #            gpus, run_count, dist_port, tag, tag_step, what, client_id,
  #            agg_strategy, pretrained, pretrained_tasks, decoder, half.
  # Exits:     status 1 on an unknown flag or on a flag with no value.
  parse_args() {
    local flag target
    while [[ "$#" -gt 0 ]]; do
      flag=$1
      # Map the flag to the name of the variable that receives its value.
      case "${flag}" in
        -p) target=partition ;;
        -t) target=tasks ;;
        -a) target=arch ;;
        -e) target=local_epoch ;;
        -k) target=clients_per_round ;;
        -b) target=batch_size ;;
        -r) target=rounds ;;
        -lr) target=lr ;;
        -lrt) target=lr_type ;;
        -te) target=test_every ;;
        -se) target=save_model_every ;;
        -gpus) target=gpus ;;
        -count) target=run_count ;;
        -port) target=dist_port ;;
        -tag) target=tag ;;
        -tag_step) target=tag_step ;;
        -what) target=what ;;
        -client_id) target=client_id ;;
        -agg_strategy) target=agg_strategy ;;
        -pretrained) target=pretrained ;;
        -pt) target=pretrained_tasks ;;
        -decoder) target=decoder ;;
        -half) target=half ;;
        *)
          echo "Unknown parameter passed: ${flag}" >&2
          exit 1
          ;;
      esac

      # A trailing flag without a value used to be accepted silently;
      # report it instead of quietly falling back to the default.
      if [[ "$#" -lt 2 ]]; then
        echo "Missing value for parameter: ${flag}" >&2
        exit 1
      fi

      # `target` is local but the named variable is not, so this assigns
      # the global of that name.
      printf -v "${target}" '%s' "$2"
      shift 2
    done
  }

  parse_args "$@"
  # --- Defaults ------------------------------------------------------------
  # Fill in every option that is still unset or empty. `${var:=value}` assigns
  # only in that case, so a non-empty value set earlier is left untouched.
  : "${partition:=partition}"
  : "${tasks:=}"                 # no default; stays empty unless provided
  : "${arch:=xception}"          # options: xception, resnet18
  : "${local_epoch:=5}"
  : "${clients_per_round:=5}"
  : "${batch_size:=64}"
  : "${lr:=0.1}"
  : "${lr_type:=poly}"
  : "${rounds:=100}"
  : "${test_every:=1}"
  : "${save_model_every:=1}"
  : "${gpus:=1}"
  : "${dist_port:=23344}"
  : "${tag:=y}"                  # whether to use task affinity grouping (lookahead)
  : "${tag_step:=10}"            # lookahead step
  : "${run_count:=0}"
  : "${client_id:=NA}"
  : "${agg_strategy:=FedAvg}"
  : "${decoder:=y}"
  : "${half:=n}"

  # Pretrained settings are coupled: without a `pretrained` value all three
  # variables are forced to 'n' (overriding any pretrained_tasks value);
  # with one, pretrained_tasks falls back to 'sdnkt'.
  if [[ -z "${pretrained:-}" ]]; then
    pretrained='n'
    use_pretrained='n'
    pretrained_tasks='n'
  else
    use_pretrained='y'
    : "${pretrained_tasks:=sdnkt}"
  fi
  # --- Launch --------------------------------------------------------------
  # The job name encodes the run configuration; it is used as the Slurm job
  # name, as the --task_id passed to main.py, and as the log file name.
  job_name="mas-${tasks}-${arch}-b${batch_size}-${lr_type}lr${lr}-${agg_strategy}-tag-${tag}-${tag_step}-e${local_epoch}-n${clients_per_round}-r${rounds}-te${test_every}-se${save_model_every}-pretrained-${use_pretrained}-${pretrained_tasks}-${what}-${run_count}"
  printf '%s\n' "${job_name}"

  # Build the main.py argument list as an array so each value stays a single
  # word even if it is empty or contains whitespace.
  train_args=(
    --data_dir "${data_dir}"
    --arch "${arch}"
    --client_file "${client_file}"
    --task_id "${job_name}"
  )
  # With an empty `tasks`, the unquoted form left `--tasks` without a value,
  # so it consumed the following flag. Only pass it when there is a value.
  # NOTE(review): assumes main.py copes with --tasks being absent - confirm.
  if [[ -n "${tasks}" ]]; then
    train_args+=(--tasks "${tasks}")
  fi
  train_args+=(
    --rotate_loss
    --batch_size "${batch_size}"
    --lr "${lr}"
    --lr_type "${lr_type}"
    --local_epoch "${local_epoch}"
    --clients_per_round "${clients_per_round}"
    --rounds "${rounds}"
    --test_every "${test_every}"
    --save_model_every "${save_model_every}"
    --random_selection
    --lookahead "${tag}"
    --lookahead_step "${tag_step}"
    --dist_port "${dist_port}"
    --run_count "${run_count}"
    --load_decoder "${decoder}"
    --half "${half}"
    --aggregation_strategy "${agg_strategy}"
    --pretrained "${pretrained}"
    --pretrained_tasks "${pretrained_tasks}"
    --client_id "${client_id}"
  )

  # Run in the background; stdout and stderr are merged, shown on the
  # terminal and copied to log/mas/<job_name>.log.
  srun -u --partition="${partition}" --job-name="${job_name}" \
    -n"${gpus}" --gres=gpu:"${gpus}" --ntasks-per-node="${gpus}" \
    python "${project_dir}/main.py" "${train_args[@]}" \
    2>&1 | tee "log/mas/${job_name}.log" &