# client.py

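"""A federated-learning client: fine-tunes a shared PEFT (LoRA) model on its
local data shard and hands the updated adapter weights back for aggregation."""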
import copy
import os
from collections import OrderedDict

import torch
import transformers
from datasets import load_dataset
from peft import (
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

class GeneralClient:
    def __init__(self, client_id, model, data_path, output_dir):
        self.client_id = client_id
        self.model = model
        # Each client reads its own shard, e.g. data_path/local_training_<id>.json.
        self.local_data_path = os.path.join(
            data_path, "local_training_{}.json".format(self.client_id)
        )
        self.local_data = load_dataset("json", data_files=self.local_data_path)
        self.output_dir = output_dir
        self.local_output_dir = os.path.join(
            self.output_dir, "trainer_saved", "local_output_{}".format(self.client_id)
        )
    def prepare_local_dataset(self, generate_and_tokenize_prompt, local_val_set_size):
        if local_val_set_size > 0:
            # Hold out a fixed-seed validation split from the client's local data.
            local_train_val = self.local_data["train"].train_test_split(
                test_size=local_val_set_size, shuffle=True, seed=42
            )
            self.local_train_dataset = (
                local_train_val["train"].shuffle().map(generate_and_tokenize_prompt)
            )
            self.local_eval_dataset = (
                local_train_val["test"].shuffle().map(generate_and_tokenize_prompt)
            )
        else:
            self.local_train_dataset = (
                self.local_data["train"].shuffle().map(generate_and_tokenize_prompt)
            )
            self.local_eval_dataset = None
        self.local_val_set_size = local_val_set_size
    def build_local_trainer(
        self,
        tokenizer,
        local_micro_batch_size,
        gradient_accumulation_steps,
        local_num_epochs,
        local_learning_rate,
        group_by_length,
        ddp,
    ):
        self.train_args = transformers.TrainingArguments(
            per_device_train_batch_size=local_micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=0,
            num_train_epochs=local_num_epochs,
            learning_rate=local_learning_rate,
            fp16=True,
            logging_steps=1,
            optim="adamw_torch",
            # Periodic evaluation only makes sense when a validation split exists.
            evaluation_strategy="steps" if self.local_val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if self.local_val_set_size > 0 else None,
            save_steps=200,
            output_dir=self.local_output_dir,
            save_total_limit=1,
            load_best_model_at_end=True if self.local_val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            dataloader_drop_last=False,
        )
        self.local_trainer = transformers.Trainer(
            model=self.model,
            train_dataset=self.local_train_dataset,
            eval_dataset=self.local_eval_dataset,
            args=self.train_args,
            # Pad to a multiple of 8 so fp16 kernels get tensor-core-friendly shapes.
            data_collator=transformers.DataCollatorForSeq2Seq(
                tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
            ),
        )
    def initiate_local_training(self):
        self.model.config.use_cache = False
        # Snapshot the adapter ("default") parameters before local training.
        self.params_dict_old = copy.deepcopy(
            OrderedDict(
                (name, param.detach())
                for name, param in self.model.named_parameters()
                if "default" in name
            )
        )
        self.params_dict_new = OrderedDict(
            (name, param.detach())
            for name, param in self.model.named_parameters()
            if "default" in name
        )
        # Replace the model's state_dict method so it returns only the adapter weights.
        self.model.state_dict = (
            lambda instance, *_, **__: get_peft_model_state_dict(
                instance, self.params_dict_new, "default"
            )
        ).__get__(self.model, type(self.model))
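
    # Note: `params_dict_new` holds detached tensors that share storage with the
    # live parameters, so it tracks the adapter as it is trained, whereas
    # `params_dict_old` is a deep copy frozen at the pre-training values. The
    # `__get__` call binds the lambda as an instance method, so callers of
    # `self.model.state_dict()` (including Trainer checkpointing, in the
    # transformers versions this code targets) receive only the small LoRA
    # adapter tensors rather than the full base model.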
    def train(self):
        self.local_trainer.train()
    def terminate_local_training(self, epoch, local_dataset_len_dict, previously_selected_clients_set):
        # Record this client's dataset size for size-weighted aggregation.
        local_dataset_len_dict[self.client_id] = len(self.local_train_dataset)
        # Persist the freshly trained adapter weights for this communication round.
        new_adapter_weight = self.model.state_dict()
        single_output_dir = os.path.join(
            self.output_dir, str(epoch), "local_output_{}".format(self.client_id)
        )
        os.makedirs(single_output_dir, exist_ok=True)
        torch.save(new_adapter_weight, os.path.join(single_output_dir, "pytorch_model.bin"))
        # Restore the snapshotted pre-training adapter so the shared model is
        # left unchanged for the next selected client.
        older_adapter_weight = get_peft_model_state_dict(self.model, self.params_dict_old, "default")
        set_peft_model_state_dict(self.model, older_adapter_weight, "default")
        previously_selected_clients_set = previously_selected_clients_set | {self.client_id}
        last_client_id = self.client_id
        return self.model, local_dataset_len_dict, previously_selected_clients_set, last_client_id
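

# A minimal end-to-end sketch of the client lifecycle (prepare -> build trainer
# -> initiate -> train -> terminate). Everything below is illustrative: the
# base checkpoint, LoRA hyperparameters, data layout under ./data, and the
# `generate_and_tokenize_prompt` callback are assumptions standing in for the
# real driver script, not part of this module's API.
if __name__ == "__main__":
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model = "huggyllama/llama-7b"  # assumed checkpoint; substitute your own
    model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    tokenizer.pad_token_id = 0  # give the collator a pad token

    # The "default" adapter name created here is what the parameter filters
    # in GeneralClient rely on.
    model = get_peft_model(
        model,
        LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"],
                   lora_dropout=0.05, task_type="CAUSAL_LM"),
    )

    def generate_and_tokenize_prompt(example):
        # Hypothetical tokenization of an instruction-style record; the real
        # prompt template lives in the driver script.
        tokens = tokenizer(
            example["instruction"] + example["output"],
            truncation=True, max_length=256,
        )
        tokens["labels"] = tokens["input_ids"].copy()
        return tokens

    # Expects ./data/local_training_0.json to exist for client 0.
    client = GeneralClient(client_id=0, model=model, data_path="./data", output_dir="./output")
    client.prepare_local_dataset(generate_and_tokenize_prompt, local_val_set_size=0)
    client.build_local_trainer(
        tokenizer,
        local_micro_batch_size=4,
        gradient_accumulation_steps=8,
        local_num_epochs=1,
        local_learning_rate=3e-4,
        group_by_length=False,
        ddp=False,
    )
    client.initiate_local_training()
    client.train()
    model, dataset_len_dict, selected_clients, last_id = client.terminate_local_training(
        epoch=0, local_dataset_len_dict={}, previously_selected_clients_set=set()
    )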