
[Release] v0.1.0

Weiming 2 年之前
当前提交
f919f838c2
共有 100 个文件被更改,包括 7817 次插入0 次删除
  1. .dockerignore (+10, -0)
  2. .gitignore (+11, -0)
  3. .readthedocs.yaml (+16, -0)
  4. LICENSE (+201, -0)
  5. Makefile (+14, -0)
  6. README.md (+95, -0)
  7. docker/base.Dockerfile (+17, -0)
  8. docker/client.Dockerfile (+9, -0)
  9. docker/docker-compose.yml (+105, -0)
  10. docker/run.Dockerfile (+9, -0)
  11. docker/server.Dockerfile (+9, -0)
  12. docker/tracker.Dockerfile (+9, -0)
  13. docs/en/Makefile (+20, -0)
  14. docs/en/_static/css/readthedocs.css (+6, -0)
  15. docs/en/_static/image/architecture.png (binary)
  16. docs/en/_static/image/docker.png (binary)
  17. docs/en/_static/image/easyfl-logo.png (binary)
  18. docs/en/_static/image/registry.png (binary)
  19. docs/en/_static/image/training-flow.png (binary)
  20. docs/en/api.rst (+40, -0)
  21. docs/en/changelog.md (+5, -0)
  22. docs/en/conf.py (+125, -0)
  23. docs/en/faq.md (+8, -0)
  24. docs/en/get_started.md (+87, -0)
  25. docs/en/index.rst (+48, -0)
  26. docs/en/introduction.md (+51, -0)
  27. docs/en/make.bat (+35, -0)
  28. docs/en/projects.md (+11, -0)
  29. docs/en/quick_run.md (+49, -0)
  30. docs/en/tutorials/config.md (+318, -0)
  31. docs/en/tutorials/customize_server_and_client.md (+281, -0)
  32. docs/en/tutorials/dataset.md (+232, -0)
  33. docs/en/tutorials/distributed_training.md (+42, -0)
  34. docs/en/tutorials/high-level_apis.md (+126, -0)
  35. docs/en/tutorials/index.rst (+11, -0)
  36. docs/en/tutorials/model.md (+92, -0)
  37. docs/en/tutorials/remote_training.md (+257, -0)
  38. easyfl/__init__.py (+22, -0)
  39. easyfl/client/__init__.py (+5, -0)
  40. easyfl/client/base.py (+471, -0)
  41. easyfl/client/service.py (+30, -0)
  42. easyfl/communication/__init__.py (+3, -0)
  43. easyfl/communication/grpc_wrapper.py (+77, -0)
  44. easyfl/compression/__init__.py (+0, -0)
  45. easyfl/config.yaml (+113, -0)
  46. easyfl/coordinator.py (+481, -0)
  47. easyfl/datasets/__init__.py (+26, -0)
  48. easyfl/datasets/cifar10/__init__.py (+1, -0)
  49. easyfl/datasets/cifar10/cifar10.py (+88, -0)
  50. easyfl/datasets/cifar100/__init__.py (+1, -0)
  51. easyfl/datasets/cifar100/cifar100.py (+88, -0)
  52. easyfl/datasets/data.py (+243, -0)
  53. easyfl/datasets/data_process/__init__.py (+0, -0)
  54. easyfl/datasets/data_process/cifar10.py (+55, -0)
  55. easyfl/datasets/data_process/cifar100.py (+55, -0)
  56. easyfl/datasets/data_process/femnist.py (+10, -0)
  57. easyfl/datasets/data_process/language_utils.py (+142, -0)
  58. easyfl/datasets/data_process/shakespeare.py (+15, -0)
  59. easyfl/datasets/dataset.py (+427, -0)
  60. easyfl/datasets/dataset_util.py (+45, -0)
  61. easyfl/datasets/femnist/__init__.py (+1, -0)
  62. easyfl/datasets/femnist/femnist.py (+109, -0)
  63. easyfl/datasets/femnist/preprocess/__init__.py (+0, -0)
  64. easyfl/datasets/femnist/preprocess/data_to_json.py (+94, -0)
  65. easyfl/datasets/femnist/preprocess/get_file_dirs.py (+71, -0)
  66. easyfl/datasets/femnist/preprocess/get_hashes.py (+55, -0)
  67. easyfl/datasets/femnist/preprocess/group_by_writer.py (+25, -0)
  68. easyfl/datasets/femnist/preprocess/match_hashes.py (+25, -0)
  69. easyfl/datasets/shakespeare/__init__.py (+1, -0)
  70. easyfl/datasets/shakespeare/shakespeare.py (+89, -0)
  71. easyfl/datasets/shakespeare/utils/__init__.py (+0, -0)
  72. easyfl/datasets/shakespeare/utils/gen_all_data.py (+17, -0)
  73. easyfl/datasets/shakespeare/utils/preprocess_shakespeare.py (+183, -0)
  74. easyfl/datasets/shakespeare/utils/shake_utils.py (+69, -0)
  75. easyfl/datasets/simulation.py (+350, -0)
  76. easyfl/datasets/utils/__init__.py (+0, -0)
  77. easyfl/datasets/utils/base_dataset.py (+158, -0)
  78. easyfl/datasets/utils/constants.py (+2, -0)
  79. easyfl/datasets/utils/download.py (+176, -0)
  80. easyfl/datasets/utils/remove_users.py (+62, -0)
  81. easyfl/datasets/utils/sample.py (+274, -0)
  82. easyfl/datasets/utils/split_data.py (+235, -0)
  83. easyfl/datasets/utils/util.py (+42, -0)
  84. easyfl/distributed/__init__.py (+18, -0)
  85. easyfl/distributed/distributed.py (+257, -0)
  86. easyfl/distributed/slurm.py (+64, -0)
  87. easyfl/encryption/__init__.py (+0, -0)
  88. easyfl/models/__init__.py (+1, -0)
  89. easyfl/models/lenet.py (+36, -0)
  90. easyfl/models/model.py (+24, -0)
  91. easyfl/models/resnet.py (+124, -0)
  92. easyfl/models/resnet18.py (+102, -0)
  93. easyfl/models/resnet50.py (+107, -0)
  94. easyfl/models/rnn.py (+36, -0)
  95. easyfl/models/simple_cnn.py (+40, -0)
  96. easyfl/models/vgg9.py (+65, -0)
  97. easyfl/pb/__init__.py (+0, -0)
  98. easyfl/pb/client_service_pb2.py (+75, -0)
  99. easyfl/pb/client_service_pb2_grpc.py (+66, -0)
  100. easyfl/pb/common_pb2.py (+17, -0)

+ 10 - 0
.dockerignore

@@ -0,0 +1,10 @@
+.git
+.vscode
+.idea
+venv/*
+metrics/*
+__pycache__
+*_pb2.py
+*_pb2_grpc.py
+*.tar.gz
+*.npy

+ 11 - 0
.gitignore

@@ -0,0 +1,11 @@
+__pycache__
+*.pyc
+*.log
+*.csv
+*.db
+*.xls
+*.xlsx
+*.egg-info
+docs/build
+
+

+ 16 - 0
.readthedocs.yaml

@@ -0,0 +1,16 @@
+version: 2
+
+formats: all
+
+sphinx:
+  fail_on_warning: false
+  
+
+python:
+  # Install our python package before building the docs
+  version: 3.7
+  install:
+    - method: pip
+      path: .
+    - requirements: requirements/docs.txt
+    - requirements: requirements/readthedocs.txt

+ 201 - 0
LICENSE

@@ -0,0 +1,201 @@
+     Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [EasyFL] [Weiming Zhuang]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 14 - 0
Makefile

@@ -0,0 +1,14 @@
+protobuf:
+	python -m grpc_tools.protoc -I./protos \
+		--python_out=. \
+		--grpc_python_out=. \
+		protos/easyfl/pb/*.proto
+
+base_image:
+	docker build -f docker/base.Dockerfile -t easyfl:base .
+
+image:
+	docker build -f docker/client.Dockerfile -t easyfl-client .
+	docker build -f docker/server.Dockerfile -t easyfl-server .
+	docker build -f docker/tracker.Dockerfile -t easyfl-tracker .
+	docker build -f docker/run.Dockerfile -t easyfl-run .

+ 95 - 0
README.md

@@ -0,0 +1,95 @@
+<div align="center">
+  <img src="docs/en/_static/image/easyfl-logo.png" width="500"/>
+  <h1 align="center">EasyFL: A Low-code Federated Learning Platform</h1>
+
+[![PyPI](https://img.shields.io/pypi/v/easyfl)](https://pypi.org/project/easyfl)
+[![docs](https://img.shields.io/badge/docs-latest-blue)](https://easyfl.readthedocs.io/en/latest/)
+[![license](https://img.shields.io/github/license/easyfl-ai/easyfl.svg)](https://github.com/easyfl-ai/easyfl/blob/master/LICENSE)
+[![maintained](https://img.shields.io/badge/Maintained%3F-YES-yellow.svg)](https://github.com/easyfl-ai/easyfl/graphs/commit-activity)
+
+[📘 Documentation](https://easyfl.readthedocs.io/en/latest/) | [🛠️ Installation](https://easyfl.readthedocs.io/en/latest/get_started.html)
+</div>
+
+## Introduction
+
+**EasyFL** is an easy-to-use federated learning (FL) platform based on PyTorch. It aims to enable users with various levels of expertise to experiment with and prototype FL applications with little or no coding.
+
+You can use it for:
+* FL Research on algorithm and system
+* Proof-of-concept (POC) of new FL applications
+* Prototype of industrial applications
+* Learning FL implementations
+
+We currently focus on horizontal FL, supporting both cross-silo and cross-device FL. You can learn more about federated learning from these [resources](https://github.com/weimingwill/awesome-federated-learning#blogs). 
+
+## Major Features
+
+**Easy to Start**
+
+EasyFL is easy to install and easy to learn. It does not have complex dependency requirements. You can run EasyFL on your personal computer with only three lines of code ([Quick Start](docs/en/quick_run.md)).
+
+**Out-of-the-box Functionalities**
+
+EasyFL provides many out-of-the-box functionalities, including datasets, models, and FL algorithms. With simple configurations, you can simulate different FL scenarios using popular datasets. We support both statistical heterogeneity simulation and system heterogeneity simulation.
+
+**Flexible, Customizable, and Reproducible**
+
+EasyFL is flexible and can be customized to your needs. You can easily migrate existing CV or NLP applications to the federated setting by writing the PyTorch code that you are most familiar with.
+
+**Multiple Training Modes**
+
+EasyFL supports **standalone training**, **distributed training**, and **remote training**. By developing the code once, you can easily speed up FL training with distributed training on multiple GPUs. Besides, you can even deploy it to Kubernetes with Docker using remote training.
+
+## Getting Started
+
+You can refer to [Get Started](docs/en/get_started.md) for installation and [Quick Run](docs/en/quick_run.md) for the simplest way of using EasyFL.
+
+For more advanced usage, we provide a list of tutorials on:
+* [High-level APIs](docs/en/tutorials/high-level_apis.md)
+* [Configurations](docs/en/tutorials/config.md)
+* [Datasets](docs/en/tutorials/dataset.md)
+* [Models](docs/en/tutorials/model.md)
+* [Customize Server and Client](docs/en/tutorials/customize_server_and_client.md)
+* [Distributed Training](docs/en/tutorials/distributed_training.md)
+* [Remote Training](docs/en/tutorials/remote_training.md)
+
+
+## Projects & Papers
+
+The following publications were developed using EasyFL.
+
+- Divergence-aware Federated Self-Supervised Learning, _ICLR'2022_. [[paper]](https://openreview.net/forum?id=oVE1z8NlNe)
+- Collaborative Unsupervised Visual Representation Learning From Decentralized Data, _ICCV'2021_. [[paper]](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html)
+- Joint Optimization in Edge-Cloud Continuum for Federated Unsupervised Person Re-identification, _ACMMM'2021_. [[paper]](https://arxiv.org/abs/2108.06493)
+
+:bulb: We will release the source code of these projects in this repository. Please stay tuned.
+
+We have been researching federated learning for several years; the following are our additional publications.
+
+- EasyFL: A Low-code Federated Learning Platform For Dummies, _IEEE Internet-of-Things Journal_. [[paper]](https://arxiv.org/abs/2105.07603)
+- Performance Optimization for Federated Person Re-identification via Benchmark Analysis, _ACMMM'2020_. [[paper]](https://weiming.me/publication/fedreid/)
+- Federated Unsupervised Domain Adaptation for Face Recognition, _ICME'22_. [[paper]](https://weiming.me/publication/fedfr/)
+
+## Join Our Community
+
+Please join our community on Slack: [easyfl.slack.com](https://easyfl.slack.com) 
+
+We will post updated features and answer questions on Slack.
+
+## License
+
+This project is released under the [Apache 2.0 license](LICENSE).
+
+## Citation
+
+If you use this platform or related projects in your research, please cite this project.
+
+```
+@article{zhuang2022easyfl,
+  title={Easyfl: A low-code federated learning platform for dummies},
+  author={Zhuang, Weiming and Gan, Xin and Wen, Yonggang and Zhang, Shuai},
+  journal={IEEE Internet of Things Journal},
+  year={2022},
+  publisher={IEEE}
+}
+```

+ 17 - 0
docker/base.Dockerfile

@@ -0,0 +1,17 @@
+FROM python:3.7.7-slim-buster
+
+WORKDIR /app
+
+COPY requirements.txt requirements.txt
+COPY Makefile Makefile
+COPY protos protos
+
+RUN apt-get update \
+    && apt-get install make \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --upgrade pip \
+    && pip install -r requirements.txt \
+    && rm -rf ~/.cache/pip
+
+RUN make protobuf

+ 9 - 0
docker/client.Dockerfile

@@ -0,0 +1,9 @@
+FROM easyfl:base
+
+WORKDIR /app
+
+COPY . .
+
+ENV PYTHONPATH=/app:$PYTHONPATH
+
+ENTRYPOINT ["python", "examples/remote_client.py"]

+ 105 - 0
docker/docker-compose.yml

@@ -0,0 +1,105 @@
+version: "3"
+services:
+  etcd0:
+    image: quay.io/coreos/etcd:v3.4.0
+    container_name: etcd
+    ports:
+      - 23790:2379
+      - 23800:2380
+    volumes:
+      - etcd0:/etcd-data
+    environment:
+      - ETCD0=localhost
+    command:
+      - /usr/local/bin/etcd
+      - -name
+      - etcd0
+      - --data-dir
+      - /etcd_data
+      - -advertise-client-urls
+      - http://etcd0:2379
+      - -listen-client-urls
+      - http://0.0.0.0:2379
+      - -initial-advertise-peer-urls
+      - http://etcd0:2380
+      - -listen-peer-urls
+      - http://0.0.0.0:2380
+      - -initial-cluster
+      - etcd0=http://etcd0:2380
+    networks:
+      - easyfl
+
+  docker-register:
+    image: wingalong/docker-register
+    container_name: docker-register
+    volumes:
+      - /var/run/docker.sock:/var/run/docker.sock
+    environment:
+      - HOST_IP=172.25.0.1
+      - ETCD_HOST=etcd0:2379
+    networks:
+      - easyfl
+    depends_on:
+      - etcd0
+
+  tracker:
+    image: easyfl-tracker
+    container_name: easyfl-tracker
+    ports:
+      - "12666:12666"
+    volumes:
+      - /home/zwm/easyfl/tracker:/app/tracker
+    networks:
+      - easyfl
+    environment:
+      - PYTHONUNBUFFERED=1
+
+  client:
+    image: easyfl-client
+    ports:
+      - "23400-23500:23400"
+    volumes:
+      - /home/zwm/easyfl/easyfl/datasets/femnist/data:/app/easyfl/datasets/femnist/data
+    command: ["--is-remote", "True", "--local-port", "23400", "--server-addr", "easyfl-server:23501", "--tracker-addr", "easyfl-tracker:12666"]
+    networks:
+      - easyfl
+    environment:
+      - PYTHONUNBUFFERED=1
+    depends_on:
+      - tracker
+#      - etcd0
+#      - docker-register
+
+  server:
+    image: easyfl-server
+    container_name: easyfl-server
+    ports:
+      - "23501:23501"
+    command: ["--is-remote", "True", "--local-port", "23501", "--tracker-addr", "easyfl-tracker:12666"]
+    networks:
+      - easyfl
+    environment:
+      - PYTHONUNBUFFERED=1
+    depends_on:
+      - tracker
+#      - etcd0
+#      - docker-register
+
+#  trigger_run:
+#    image: easyfl-run
+#    command:
+#      - --server-addr
+#      - 172.21.0.1:23501
+#      - --etcd-addr
+#      - 172.21.0.1:2379
+#    networks:
+#      - easyfl
+#    depends_on:
+#      - client
+#      - server
+
+volumes:
+  etcd0:
+
+networks:
+  easyfl:

+ 9 - 0
docker/run.Dockerfile

@@ -0,0 +1,9 @@
+FROM easyfl:base
+
+WORKDIR /app
+
+COPY . .
+
+ENV PYTHONPATH=/app:$PYTHONPATH
+
+ENTRYPOINT ["python", "examples/remote_run.py"]

+ 9 - 0
docker/server.Dockerfile

@@ -0,0 +1,9 @@
+FROM easyfl:base
+
+WORKDIR /app
+
+COPY . .
+
+ENV PYTHONPATH=/app:$PYTHONPATH
+
+ENTRYPOINT ["python", "examples/remote_server.py"]

+ 9 - 0
docker/tracker.Dockerfile

@@ -0,0 +1,9 @@
+FROM easyfl:base
+
+WORKDIR /app
+
+COPY . .
+
+ENV PYTHONPATH=/app:$PYTHONPATH
+
+ENTRYPOINT ["python", "examples/remote_tracker.py"]

+ 20 - 0
docs/en/Makefile

@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS    ?=
+SPHINXBUILD   ?= sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

+ 6 - 0
docs/en/_static/css/readthedocs.css

@@ -0,0 +1,6 @@
+.header-logo {
+    background-image: url("../image/easyfl-logo.png");
+    background-size: 156px 40px;
+    width: 156px;
+    height: 40px;
+}

Binary
docs/en/_static/image/architecture.png

Binary
docs/en/_static/image/docker.png

Binary
docs/en/_static/image/easyfl-logo.png

Binary
docs/en/_static/image/registry.png

Binary
docs/en/_static/image/training-flow.png


+ 40 - 0
docs/en/api.rst

@@ -0,0 +1,40 @@
+easyfl
+--------------
+.. automodule:: easyfl
+    :members:
+
+easyfl.server
+--------------
+.. automodule:: easyfl.server
+    :members:
+
+easyfl.client
+--------------
+.. automodule:: easyfl.client
+    :members:
+
+easyfl.distributed
+------------------
+.. automodule:: easyfl.distributed
+    :members:
+
+easyfl.dataset
+--------------
+.. automodule:: easyfl.datasets
+    :members:
+
+easyfl.models
+--------------
+.. automodule:: easyfl.models
+    :members:
+
+easyfl.communication
+--------------------
+.. automodule:: easyfl.communication
+    :members:
+
+easyfl.registry
+---------------
+.. automodule:: easyfl.registry
+    :members:
+

+ 5 - 0
docs/en/changelog.md

@@ -0,0 +1,5 @@
+## Changelog
+
+### v0.1.0 (05/04/2022)
+
+- Official release

+ 125 - 0
docs/en/conf.py

@@ -0,0 +1,125 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import subprocess
+import sys
+
+import pytorch_sphinx_theme
+
+sys.path.insert(0, os.path.abspath('../..'))
+
+# -- Project information -----------------------------------------------------
+
+project = 'EasyFL'
+copyright = '2020-2022, EasyFL'
+author = 'EasyFL Authors'
+version_file = '../../easyfl/version.py'
+
+
+def get_version():
+    with open(version_file, 'r') as f:
+        exec(compile(f.read(), version_file, 'exec'))
+    return locals()['__version__']
+
+
+# # The full version, including alpha/beta/rc tags
+release = get_version()
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    'sphinx.ext.autodoc',
+    'sphinx.ext.napoleon',
+    'sphinx.ext.viewcode',
+    'recommonmark',
+    'sphinx_markdown_tables',
+    'sphinx_copybutton',
+]
+
+autodoc_mock_imports = [
+    'matplotlib', 'pycocotools', 'terminaltables', 'mmdet.version', 'mmcv.ops'
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = {
+    '.rst': 'restructuredtext',
+    '.md': 'markdown',
+}
+
+# The master toctree document.
+master_doc = 'index'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+# html_theme = 'sphinx_rtd_theme'
+html_theme = 'pytorch_sphinx_theme'
+html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
+
+html_theme_options = {
+    'menu': [
+        {
+            'name': 'Get Started',
+            'url': 'get_started.html'
+        },
+        {
+            'name': 'Tutorials',
+            'url': 'tutorials/high-level_apis.html'
+        },
+        {
+            'name': 'API',
+            'url': 'api.html'
+        },
+        {
+            'name': 'GitHub',
+            'url': 'https://github.com/EasyFL-AI/easyfl'
+        },
+    ],
+
+    # Specify the language of shared menu
+    'menu_lang': 'en'
+}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+html_css_files = ['css/readthedocs.css']
+
+# -- Extension configuration -------------------------------------------------
+# Ignore >>> when copying code
+copybutton_prompt_text = r'>>> |\.\.\. '
+copybutton_prompt_is_regexp = True
+
+
+# def builder_inited_handler(app):
+#     subprocess.run(['./stat.py'])
+
+
+# def setup(app):
+#     app.connect('builder-inited', builder_inited_handler)

+ 8 - 0
docs/en/faq.md

@@ -0,0 +1,8 @@
+# Frequently Asked Questions
+
+We list some common issues faced by users and their corresponding solutions here. Feel free to enrich the list if you find frequent issues and have ways to help others solve them. If the contents here do not cover your issue, please create an issue using the [provided templates](https://github.com/EasyFL-AI/easyfl/blob/master/.github/ISSUE_TEMPLATE/error-report.md/) and make sure you fill in all required information in the template.
+
+## EasyFL Installation
+
+Waiting for your input :)
+

+ 87 - 0
docs/en/get_started.md

@@ -0,0 +1,87 @@
+## Prerequisites
+
+- Linux or macOS (Windows is in experimental support)
+- Python 3.6+
+- PyTorch 1.3+
+- CUDA 9.2+ (If you run using GPU)
+
+## Installation
+
+### Prepare environment
+
+1. Create a conda virtual environment and activate it.
+
+    ```shell
+    conda create -n easyfl python=3.7 -y
+    conda activate easyfl
+    ```
+
+2. Install PyTorch and torchvision following the [official instructions](https://pytorch.org/), e.g.,
+
+    ```shell
+    conda install pytorch torchvision -c pytorch
+    ```
+    or
+    ```shell
+    pip install torch==1.10.1 torchvision==0.11.2
+    ```
+
+3. _You can skip the following CUDA-related content if you plan to run on CPU._ Make sure that your compilation CUDA version and runtime CUDA version match.
+
+    You can check the supported CUDA version for precompiled packages on the [PyTorch website](https://pytorch.org/).
+
+    `E.g.,` 1. If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install
+    PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1.
+
+    ```shell
+    conda install pytorch cudatoolkit=10.1 torchvision -c pytorch
+    ```
+
+    `E.g.,` 2. If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install
+    PyTorch 1.3.1., you need to install the prebuilt PyTorch with CUDA 9.2.
+
+    ```shell
+    conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch
+    ```
+
+    If you build PyTorch from source instead of installing the prebuilt package,
+    you can use more CUDA versions such as 9.0.
+
+### Install EasyFL
+
+```shell
+pip install easyfl
+```
+
+### A from-scratch setup script
+
+Assuming that you already have CUDA 10.1 installed, here is a full script for setting up EasyFL with conda.
+
+```shell
+conda create -n easyfl python=3.7 -y
+conda activate easyfl
+
+# Without GPU
+conda install pytorch==1.6.0 torchvision==0.7.0 -c pytorch -y
+
+# With GPU
+conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch -y
+
+# install easyfl
+git clone https://github.com/EasyFL-AI/easyfl.git
+cd easyfl
+pip install -v -e .
+```
+
+## Verification
+
+To verify whether EasyFL is installed correctly, we can run the following sample code to test.
+
+```python
+import easyfl
+
+easyfl.init()
+```
+
+The above code is supposed to run successfully after you finish the installation.
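+
+If you want a slightly fuller smoke test, the following minimal sketch (reusing the configuration keys shown in the [Quick Run](quick_run.md) note; the exact values are only an example) runs one short training round:
+
+```python
+import easyfl
+
+# One round with two clients on the default FEMNIST setup,
+# kept small so it finishes quickly on a personal computer.
+config = {
+    "server": {"rounds": 1, "clients_per_round": 2},
+    "client": {"local_epoch": 1},
+}
+easyfl.init(config)
+easyfl.run()
+```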

+ 48 - 0
docs/en/index.rst

@@ -0,0 +1,48 @@
+Welcome to EasyFL's documentation!
+==================================
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Introduction
+
+   introduction.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Get Started
+
+   get_started.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Quick Run
+
+   quick_run.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   tutorials/index.rst
+
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Notes
+
+   projects.md
+   changelog.md
+   faq.md
+
+.. toctree::
+   :maxdepth: 2
+   :caption: API Reference
+
+   api.rst
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`search`

+ 51 - 0
docs/en/introduction.md

@@ -0,0 +1,51 @@
+## Why EasyFL?
+
+**EasyFL** is an easy-to-use federated learning platform that aims to enable users with various levels of expertise to experiment with and prototype FL applications with little or no coding.
+
+You can use it for:
+* FL Research on algorithm and system
+* Proof-of-concept (POC) of new FL applications
+* Prototype of industrial applications
+* Learning FL implementations
+
+We currently focus on horizontal FL, supporting both cross-silo and cross-device FL. You can learn more about federated learning from these [resources](https://github.com/weimingwill/awesome-federated-learning#blogs). 
+
+### Major Features
+
+**Easy to Start**
+
+EasyFL is easy to install and easy to learn. It does not have complex dependency requirements. You can run EasyFL on your personal computer with only three lines of code ([Quick Start](quick_run.md)).
+
+**Out-of-the-box Functionalities**
+
+EasyFL provides many out-of-the-box functionalities, including datasets, models, and FL algorithms. With simple configurations, you can simulate different FL scenarios using popular datasets. We support both statistical heterogeneity simulation and system heterogeneity simulation.
+
+**Flexible, Customizable, and Reproducible**
+
+EasyFL is flexible and can be customized to your needs. You can easily migrate existing CV or NLP applications to the federated setting by writing the PyTorch code that you are most familiar with.
+
+**Multiple Training Modes**
+
+EasyFL supports **standalone training**, **distributed training**, and **remote training**. By developing the code once, you can easily speed up FL training with distributed training on multiple GPUs. Besides, you can even deploy it to Kubernetes with Docker using remote training.
+
+We have developed many applications and published several [papers](projects.md) in top-tier conferences and journals using EasyFL. We believe that EasyFL will empower your FL research and development.
+
+## Architecture Overview
+
+Here we introduce the architecture of EasyFL. You can jump directly to [Get Started](get_started.md) without knowing these details.
+
+EasyFL's architecture comprises an **interface layer** and a modularized **system layer**. The interface layer provides simple APIs for high-level applications, and the system layer contains the complex implementations that accelerate training and shorten deployment time with out-of-the-box functionalities.
+
+![architecture](_static/image/architecture.png)
+
+**Interface Layer**: The interface layer provides a common interface across FL applications. It contains APIs that are designed to encapsulate complex system implementations from users. These APIs decouple application-specific models, datasets, and algorithms such that EasyFL is generic enough to support a wide range of applications, such as computer vision and healthcare.
+ 
+**System Layer**: The system layer supports and manages the FL life cycle. It consists of eight modules that support the FL training pipeline and life cycle:
+1. The simulation manager initializes the experimental environment with heterogeneous simulations. 
+2. The data manager loads training and testing datasets, and the model manager loads the model. 
+3. A server and the clients start training and testing with FL algorithms such as FedAvg. 
+4. The distribution manager optimizes the training speed of distributed training. 
+5. The tracking manager collects the evaluation metrics and provides methods to query training results. 
+6. The deployment manager seamlessly deploys FL and scales FL applications in production.
+
+To learn more about EasyFL, you can check out our [paper](https://arxiv.org/abs/2105.07603).
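+
+As a rough sketch of how the interface layer drives the system layer (only APIs shown elsewhere in these docs are used; the mapping in the comments is illustrative):
+
+```python
+import easyfl
+
+# Interface layer calls; the system layer modules do the heavy lifting underneath.
+config = easyfl.load_config("config.yaml")  # consumed by the simulation, data, and model managers
+easyfl.init(config)                         # prepares datasets, the model, the server, and the clients
+easyfl.run()                                # runs the FL pipeline: selection, training, aggregation, testing
+```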

+ 35 - 0
docs/en/make.bat

@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+	set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+	echo.
+	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+	echo.installed, then set the SPHINXBUILD environment variable to point
+	echo.to the full path of the 'sphinx-build' executable. Alternatively you
+	echo.may add the Sphinx directory to PATH.
+	echo.
+	echo.If you don't have Sphinx installed, grab it from
+	echo.http://sphinx-doc.org/
+	exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd

+ 11 - 0
docs/en/projects.md

@@ -0,0 +1,11 @@
+# Projects based on EasyFL
+
+We have built several projects based on EasyFL and published four papers in top-tier conferences and journals. 
+We list them as examples of how to extend EasyFL for your projects.
+
+- EasyFL: A Low-code Federated Learning Platform For Dummies, _IEEE Internet-of-Things Journal_. [[paper]](https://arxiv.org/abs/2105.07603)
+- Divergence-aware Federated Self-Supervised Learning, _ICLR'2022_. [[paper]](https://openreview.net/forum?id=oVE1z8NlNe)
+- Collaborative Unsupervised Visual Representation Learning From Decentralized Data, _ICCV'2021_. [[paper]](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html)
+- Joint Optimization in Edge-Cloud Continuum for Federated Unsupervised Person Re-identification, _ACMMM'2021_. [[paper]](https://arxiv.org/abs/2108.06493)
+
+If you have built new projects using EasyFL, please feel free to create a PR to update this page.

+ 49 - 0
docs/en/quick_run.md

@@ -0,0 +1,49 @@
+## High-level Introduction
+
+EasyFL provides numerous existing models and datasets. Models include LeNet, RNN, VGG9, and ResNet. Datasets include FEMNIST, Shakespeare, CIFAR-10, and CIFAR-100.
+This note will present how to start training with these existing models and standard datasets.
+
+EasyFL provides three types of high-level APIs: **registration**, **initialization**, and **execution**.
+Registration is for registering customized components, which we will introduce in the following notes.
+In this note, we focus on **initialization** and **execution**.
+
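+Below is a glance at how the three API types fit together; registration is optional, and `register_server` is just one example taken from the customization tutorial.
+
+```python
+import easyfl
+from easyfl.server import BaseServer
+
+# Registration (optional): plug in a customized component, e.g., a server subclass.
+class MyServer(BaseServer):
+    pass
+
+easyfl.register_server(MyServer)  # registration
+easyfl.init()                     # initialization
+easyfl.run()                      # execution
+```
+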
+## Simplest Run
+
+We can run federated learning with only two lines of code (not counting the import statement).
+It executes training with default configurations: simulating 100 clients with the FEMNIST dataset and randomly selecting 5 clients for training in each training round.
+We explain more about the configurations in [another note](tutorials/config.md).
+
+Note: we package a default partitioning of the FEMNIST data to avoid downloading the whole dataset.
+
+```python
+import easyfl
+
+# Initialize federated learning with default configurations.
+easyfl.init()
+# Execute federated learning training.
+easyfl.run()
+```
+
+## Run with Configurations
+
+You can specify configurations to overwrite the default configurations.
+
+```python
+import easyfl
+
+# Customized configuration.
+config = {
+    "data": {"dataset": "cifar10", "split_type": "class", "num_of_clients": 100},
+    "server": {"rounds": 5, "clients_per_round": 2},
+    "client": {"local_epoch": 5},
+    "model": "resnet18",
+    "test_mode": "test_in_server",
+}
+# Initialize federated learning with default configurations.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+In the example above, we train a ResNet-18 model on the CIFAR-10 dataset, which is partitioned into 100 clients by label `class`.
+Training runs with 2 clients per round for 5 rounds, and each selected client trains for 5 local epochs in each round.

+ 318 - 0
docs/en/tutorials/config.md

@@ -0,0 +1,318 @@
+# Tutorial 2: Configurations
+
+Configurations in EasyFL control the federated learning (FL) training behavior. They specify data simulation, the model for training, training hyperparameters, distributed training, and more.
+
+We provide [default configs](#default-configurations) in EasyFL, and there are two ways you can modify them: using a Python dictionary or using a yaml file.
+
+## Modify Config
+
+EasyFL provides two ways to modify the configurations: using a Python dictionary and using a yaml file. Either way, if the new configs exist in the default configuration, they overwrite those specific fields; if they do not exist, they are added to the EasyFL configuration. Thus, you can either modify the default configurations or add new configurations based on your application needs.
+
+### 1. Modify Using Python Dictionary
+
+You can create a new Python dictionary to specify configurations. These configs take effect when you initialize EasyFL with them by calling `easyfl.init(config)`.   
+ 
+The examples provided in the previous [tutorial](high-level_apis.md) demonstrate how to modify config via a Python dictionary.  
+```python
+import easyfl
+
+# Define customized configurations.
+config = {
+    "data": {
+        "dataset": "cifar10", 
+        "num_of_clients": 1000
+    },
+    "server": {
+        "rounds": 5, 
+        "clients_per_round": 2
+    },
+    "client": {"local_epoch": 5},
+    "model": "resnet18",
+    "test_mode": "test_in_server",
+}
+# Initialize EasyFL with the new config.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+### 2. Modify Using A Yaml File
+
+You can create a new yaml file named `config.yaml` for configuration and load them into EasyFL.
+
+```python
+import easyfl
+# Define customized configurations in a yaml file.
+config_file = "config.yaml"
+# Load the yaml file as config.
+config = easyfl.load_config(config_file)
+# Initialize EasyFL with the new config.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+You can also combine these two methods of modifying configs.
+
+```python
+import easyfl
+
+# Define part of customized configs.
+config = {
+    "data": {
+        "dataset": "cifar10", 
+        "num_of_clients": 1000
+    },
+    "server": {
+        "rounds": 5, 
+        "clients_per_round": 2
+    },
+    "client": {"local_epoch": 5},
+    "model": "resnet18",
+    "test_mode": "test_in_server",
+}
+
+# Define part of configs in a yaml file.
+config_file = "config.yaml"
+# Load and combine these two configs.
+config = easyfl.load_config(config_file, config)
+# Initialize EasyFL with the new config.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+## A Common Practice to Modify Configuration
+
+Since some configurations are directly related to training, we may need to set them dynamically with different values. 
+
+For example, we may want to experiment with the effect of batch size and local epoch on federated learning. Instead of manually changing the values in the configuration each time, you can pass them in as command-line arguments and set them with different commands.
+
+```python
+import easyfl
+import argparse
+
+# Define command line arguments.
+parser = argparse.ArgumentParser(description='Example')
+parser.add_argument("--batch_size", type=int, default=32)
+parser.add_argument("--local_epoch", type=int, default=5)
+args = parser.parse_args()
+print("args", args)
+
+# Define customized configurations using the arguments.
+config = {
+    "client": {
+        "batch_size": args.batch_size,
+        "local_epoch": args.local_epoch,
+    }
+}
+# Initialize EasyFL with the new config.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+
+## Default Configurations
+
+The following are the default configurations in EasyFL.
+They are copied from `easyfl/config.yaml` as of April 2022.
+
+We provide more details on how to simulate different FL scenarios with the out-of-the-box datasets in [another note](dataset.md).  
+
+```yaml
+# The unique identifier for each federated learning task
+task_id: ""
+
+# Provide dataset and federated learning simulation related configuration.
+data:
+  # The root directory where datasets are stored.
+  root: "./data/"
+  # The name of the dataset, support: femnist, shakespeare, cifar10, and cifar100.
+  dataset: femnist
+  # The data distribution of each client, support: iid, niid (for femnist and shakespeare), and dir and class (for cifar datasets).
+    # `iid` means independent and identically distributed data.
+    # `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
+    # `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+    # `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+  split_type: "iid"
+  
+  # The minimal number of samples in each client. It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+  min_size: 10
+  # The fraction of data sampled for LEAF datasets. e.g., 10% means that only 10% of total dataset size are used.
+  data_amount: 0.05
+  # The fraction of the number of clients used when the split_type is 'iid'.
+  iid_fraction: 0.1
+  # Whether to partition users of the dataset into train-test groups. Only applicable to femnist and shakespeare datasets.
+    # True means partitioning the users of the dataset into train-test groups.
+    # False means partitioning each user's samples into train-test groups.
+  user: False
+  # The fraction of data for training; the rest are for testing.
+  train_test_split: 0.9
+
+  # The number of classes in each client. Only applicable when the split_type is 'class'.  
+  class_per_client: 1
+  # The targeted number of clients to construct. For non-LEAF datasets, it is the number of clients the data is split into; for LEAF datasets, it is only used when split_type is 'class'.
+  num_of_clients: 100
+
+  # The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir` for CIFAR datasets.
+  alpha: 0.5
+
+  # The targeted distribution of quantities to simulate data quantity heterogeneity.
+    # The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+    # The `num_of_clients` should be divisible by `len(weights)`.
+    # None means clients are simulated with the same data quantity.
+  weights: NULL
+
+# The name of the model for training, support: lenet, rnn, resnet, resnet18, resnet50, vgg9.
+model: lenet
+# How to conduct testing, options: test_in_client or test_in_server.
+  # `test_in_client` means that each client has a test set to run testing.
+  # `test_in_server` means that server has a test set to run testing for the global model. Use this mode for cifar datasets.
+test_mode: "test_in_client"
+# The way to measure testing performance (accuracy) when test mode is `test_in_client`, support: average or weighted (means weighted average).
+test_method: "average"
+
+server:
+  track: False  # Whether to track server metrics using the tracking service.
+  rounds: 10  # Total training round.
+  clients_per_round: 5  # The number of clients to train in each round.
+  test_every: 1  # The frequency of testing: conduct testing every N round.
+  save_model_every: 10  # The frequency of saving model: save model every N round.
+  save_model_path: ""  # The path to save model. Default path is root directory of the library.
+  batch_size: 32  # The batch size of test_in_server.
+  test_all: True  # Whether to test all clients or only the selected clients.
+  random_selection: True  # Whether to select clients randomly for training.
+  # The strategy to aggregate client uploaded models, options: FedAvg, equal.
+    # FedAvg aggregates models using weighted average, where the weights are data size of clients.
+    # equal aggregates model by simple averaging.
+  aggregation_stragtegy: "FedAvg"
+  # The content of aggregation, options: all, parameters.
+    # all means aggregating models using state_dict, including both model parameters and persistent buffers like BatchNorm stats.
+    # parameters means aggregating only model parameters.
+  aggregation_content: "all"
+
+client:
+  track: False  # Whether to track client metrics using the tracking service.
+  batch_size: 32  # The batch size of training in client.
+  test_batch_size: 5  # The batch size of testing in client.
+  local_epoch: 10  # The number of epochs to train in each round.
+  optimizer:
+    type: "Adam"  # The name of the optimizer, options: Adam, SGD.
+    lr: 0.001
+    momentum: 0.9
+    weight_decay: 0
+  seed: 0
+  local_test: False  # Whether to test the trained models in clients before uploading them to the server.
+
+gpu: 0  # The total number of GPUs used in training. 0 means CPU.
+distributed:  # The distributed training configurations. It is only applicable when gpu > 1.
+  backend: "nccl"  # The distributed backend.
+  init_method: ""
+  world_size: 0
+  rank: 0
+  local_rank: 0
+
+tracking:  # The configurations for logging and tracking.
+  database: ""  # The path of local dataset, sqlite3.
+  log_file: ""
+  log_level: "INFO"  # The level of logging.
+  metric_file: ""
+  save_every: 1
+
+# The configuration for system heterogeneity simulation.
+resource_heterogeneous:
+  simulate: False  # Whether to simulate system heterogeneity in federated learning.
+  # The type of heterogeneity to simulate, support: iso, dir, real.
+  hetero_type: "real"
+  level: 3  # The level of heterogeneity (0-5); 0 means no heterogeneity among clients.
+  sleep_group_num: 1000  # The number of groups with different sleep time. 1 means all clients are the same.
+  total_time: 1000  # The total sleep time of all clients, unit: second.
+  fraction: 1  # The fraction of clients attending heterogeneous simulation.
+  grouping_strategy: "greedy"  # The grouping strategy to handle system heterogeneity, support: random, greedy, slowest.
+  initial_default_time: 5  # The estimated default training time for each training round, unit: second.
+  default_time_momentum: 0.2  # The default momentum for default time update.
+
+seed: 0  # The random seed.
+```
+
+### Default Config without Comments
+
+```yaml
+task_id: ""
+data:
+  root: "./data/"
+  dataset: femnist
+  split_type: "iid"
+  
+  min_size: 10
+  data_amount: 0.05
+  iid_fraction: 0.1
+  user: False
+  
+  class_per_client: 1
+  num_of_clients: 100
+  train_test_split: 0.9  
+  alpha: 0.5
+  
+  weights: NULL
+  
+model: lenet
+test_mode: "test_in_client"
+test_method: "average"
+
+server:
+  track: False
+  rounds: 10
+  clients_per_round: 5
+  test_every: 1
+  save_model_every: 10
+  save_model_path: ""
+  batch_size: 32
+  test_all: True
+  random_selection: True
+  aggregation_stragtegy: "FedAvg"
+  aggregation_content: "all"
+
+client:
+  track: False
+  batch_size: 32
+  test_batch_size: 5
+  local_epoch: 10
+  optimizer:
+    type: "Adam"
+    lr: 0.001
+    momentum: 0.9
+    weight_decay: 0
+  seed: 0
+  local_test: False
+
+gpu: 0
+distributed:
+  backend: "nccl"
+  init_method: ""
+  world_size: 0
+  rank: 0
+  local_rank: 0
+
+tracking:
+  database: ""
+  log_file: ""
+  log_level: "INFO"
+  metric_file: ""
+  save_every: 1
+
+resource_heterogeneous:
+  simulate: False
+  hetero_type: "real"
+  level: 3
+  sleep_group_num: 1000
+  total_time: 1000
+  fraction: 1
+  grouping_strategy: "greedy"
+  initial_default_time: 5
+  default_time_momentum: 0.2
+
+seed: 0  
+```
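+
+As a minimal sketch of how a few of these defaults can be overridden from Python (the keys come from the configuration above; the values are only illustrative):
+
+```python
+import easyfl
+
+# Illustrative overrides: Dirichlet non-IID CIFAR-10 split across 100 clients,
+# trained with ResNet-18 and tested in the server.
+config = {
+    "data": {"dataset": "cifar10", "split_type": "dir", "alpha": 0.5, "num_of_clients": 100},
+    "model": "resnet18",
+    "test_mode": "test_in_server",
+    "server": {"rounds": 10, "clients_per_round": 5},
+    "client": {"local_epoch": 5},
+}
+easyfl.init(config)
+easyfl.run()
+```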

+ 281 - 0
docs/en/tutorials/customize_server_and_client.md

@@ -0,0 +1,281 @@
+# Tutorial 5: Customize Server and Client
+
+EasyFL abstracts the federated learning (FL) training flow in the server and the client into granular stages, as shown in the image below.
+
+![Training Flow](../_static/image/training-flow.png)
+
+You have the flexibility to customize any stage of the training flow while reusing the rest by implementing a customized server/client.  
+
+## Customize Server
+
+EasyFL implements random client selection and [Federated Averaging](https://arxiv.org/abs/1602.05629) as the aggregation strategy. 
+You can customize the server implementation by inheriting [BaseServer](../api.html#easyfl.server.BaseServer) and overriding specific functions.
+
+Below is an example of a customized server. 
+```python
+import easyfl
+from easyfl.server import BaseServer
+from easyfl.server.base import MODEL
+
+class CustomizedServer(BaseServer):
+    def __init__(self, conf, **kwargs):
+        super(CustomizedServer, self).__init__(conf, **kwargs)
+        pass  # more initialization of attributes.
+    
+    def aggregation(self):
+        uploaded_content = self.get_client_uploads()
+        models = list(uploaded_content[MODEL].values())
+        # Original implementation of aggregation weights
+        # weights = list(uploaded_content[DATA_SIZE].values())
+        # Instead, we can assign customized weights for aggregation.
+        customized_weights = list(range(len(models)))
+        model = self.aggregate(models, customized_weights)
+        self.set_model(model, load_dict=True)
+
+# Register customized server.
+easyfl.register_server(CustomizedServer)
+# Initialize federated learning with default configurations.
+easyfl.init()
+# Execute federated learning training.
+easyfl.run()
+```
+
+Here we list more useful functions that you can override to implement a customized server.
+
+```python
+import easyfl
+from easyfl.server import BaseServer
+
+class CustomizedServer(BaseServer):
+    def __init__(self, conf, **kwargs):
+        super(CustomizedServer, self).__init__(conf, **kwargs)
+        pass  # more initialization of attributes.
+    
+    def selection(self, clients, clients_per_round):
+        pass  # implement customized client selection algorithm.
+    
+    def compression(self):
+        pass  # implement customized compression algorithm.
+    
+    def pre_train(self):
+        pass  # inject operations before distribution to train.
+    
+    def post_train(self):
+        pass  # inject operations after aggregation.
+    
+    def pre_test(self):
+        pass  # inject operations before distribution to test. 
+    
+    def post_test(self):
+        pass  # inject operations after aggregating testing results.
+    
+    def decompression(self, model):
+        pass  # implement customized decompression algorithm.
+    
+    def aggregation(self):
+        pass  # implement customized aggregation algorithm.
+```
+
+Below are some attributes that you may need in implementing the customized server.
+
+`self.conf`: Configurations of EasyFL.
+
+`self._model`: The global model in the server, updated after aggregation.
+
+`self._current_round`: The current training round.
+
+`self._clients`: All available clients.
+
+`self.selected_clients`: The selected clients.
+
+You may refer to the [BaseServer](../api.html#easyfl.server.BaseServer) for more functions and class attributes.
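+
+As a small illustration of these hooks and attributes (a sketch, not part of EasyFL itself), the server below logs the round number before training and checkpoints the aggregated global model every five rounds; the checkpoint filename is hypothetical:
+
+```python
+import torch
+import easyfl
+from easyfl.server import BaseServer
+
+class CheckpointServer(BaseServer):
+    def pre_train(self):
+        # Inject a simple log before the global model is distributed for training.
+        print("Starting training round {}".format(self._current_round))
+
+    def post_train(self):
+        # After aggregation, periodically persist the aggregated global model.
+        if self._current_round % 5 == 0:
+            torch.save(self._model.state_dict(), "global_round_{}.pth".format(self._current_round))
+
+easyfl.register_server(CheckpointServer)
+easyfl.init()
+easyfl.run()
+```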
+
+## Customize Client
+
+Each client of EasyFL conducts training and testing. 
+The implementation of training and testing is similar to a normal PyTorch implementation.
+We implement training with an Adam/SGD optimizer and cross-entropy loss. 
+You can customize client implementation of training and testing by inheriting [BaseClient](../api.html#easyfl.client.BaseClient) and overriding specific functions. 
+
+Below is an example of a customized client. 
+
+```python
+import time
+import easyfl
+from torch import nn
+import torch.optim as optim
+from easyfl.client.base import BaseClient
+
+# Inherit BaseClient to implement customized client operations.
+class CustomizedClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
+        super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
+        # Initialize a classifier for each client.
+        self.classifier = nn.Sequential(*[nn.Linear(512, 100)])
+
+    def train(self, conf, device):
+        start_time = time.time()
+        self.model.classifier.classifier = self.classifier.to(device)
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        self.train_loss = []
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for batched_x, batched_y in self.train_loader:
+                x, y = batched_x.to(device), batched_y.to(device)
+                optimizer.zero_grad()
+                out = self.model(x)
+                loss = loss_fn(out, y)
+                loss.backward()
+                optimizer.step()
+                batch_loss.append(loss.item())
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            self.train_loss.append(float(current_epoch_loss))
+        self.train_time = time.time() - start_time
+        # Keep the classifier in clients and upload only the backbone of model. 
+        self.classifier = self.model.classifier.classifier
+        self.model.classifier.classifier = nn.Sequential()        
+
+    # A customized optimizer that sets different learning rates for different model parts.
+    def load_optimizer(self, conf):
+        ignored_params = list(map(id, self.model.classifier.parameters()))
+        base_params = filter(lambda p: id(p) not in ignored_params, self.model.parameters())
+        optimizer = optim.SGD([
+            {'params': base_params, 'lr': 0.1 * conf.optimizer.lr},
+            {'params': self.model.classifier.parameters(), 'lr': conf.optimizer.lr}
+        ], weight_decay=5e-4, momentum=conf.optimizer.momentum, nesterov=True)
+        return optimizer
+
+# Register customized client.
+easyfl.register_client(CustomizedClient)
+# Initialize federated learning with default configurations.
+easyfl.init()
+# Execute federated learning training.
+easyfl.run()
+```
+
+Here we list more useful functions that you can override to implement a customized client.
+
+```python
+import easyfl
+from easyfl.client import BaseClient
+
+# Inherit BaseClient to implement customized client operations.
+class CustomizedClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
+        super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
+        pass  # more initialization of attributes.
+
+    def decompression(self):
+        pass  # implement decompression method.
+
+    def pre_train(self):
+        pass  # inject operations before training. 
+
+    def train(self, conf, device):
+        pass  # implement customized training method.
+    
+    def post_train(self):
+        pass  # inject operations after training.
+    
+    def load_loss_fn(self, conf):
+        pass  # load and return a customized loss function.
+
+    def load_optimizer(self, conf):
+        pass  # load and return a customized optimizer.
+
+    def load_loader(self, conf):
+        pass  # load and return a customized training data loader.
+
+    def test_local(self):
+        pass  # implement testing of the trained model before uploading to the server.
+
+    def pre_test(self):
+        pass  # inject operations before testing. 
+
+    def test(self, conf, device):
+        pass  # implement customized testing.
+    
+    def post_test(self):
+        pass  # inject operations after testing.
+
+    def encryption(self):
+        pass  # implement customized encryption method.
+
+    def compression(self):
+        pass  # implement customized compression method.
+
+    def upload(self):
+        pass  # implement customized upload method.
+
+    def post_upload(self):
+        pass  # implement customized post upload method.
+```
+
+Below are some attributes that you may need in implementing the customized client.
+
+`self.conf`: Configurations of the client, under the "client" key of the config dictionary.
+
+`self.compressed_model`: The model downloaded from the server.
+
+`self.model`: The model used for training.
+
+`self.cid`: The client id.
+
+`self.device`: The device for training. 
+
+`self.train_data`: The training data of the client.
+
+`self.test_data`: The testing data of the client.
+
+You may refer to the [BaseClient](../api.html#easyfl.client.BaseClient) for more functions and class attributes.
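+
+As a small illustration (a sketch only), the client below overrides `load_loss_fn`, `load_loader`, and `post_train` using the attributes above; the label-smoothing argument assumes a recent PyTorch version, and the loader call mirrors the one used in `BaseClient`:
+
+```python
+import torch
+import easyfl
+from easyfl.client import BaseClient
+
+class VerboseClient(BaseClient):
+    def load_loss_fn(self, conf):
+        # Cross-entropy with label smoothing (assumes PyTorch >= 1.10).
+        return torch.nn.CrossEntropyLoss(label_smoothing=0.1)
+
+    def load_loader(self, conf):
+        # Same loader call as BaseClient, but with a doubled batch size.
+        return self.train_data.loader(conf.batch_size * 2, self.cid, shuffle=True, seed=conf.seed)
+
+    def post_train(self):
+        # Report the loss of the last local epoch after training.
+        print("Client {} finished training, last epoch loss: {:.4f}".format(self.cid, self.train_loss[-1]))
+
+easyfl.register_client(VerboseClient)
+easyfl.init()
+easyfl.run()
+```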
+
+## Existing Works
+
+We surveyed 33 recent FL publications from both the machine learning and the systems communities. 
+The following table shows that 10 out of 33 (~30%) publications propose new algorithms with changes in only one stage of the training flow, and the majority (~57%) change only two stages. 
+The training flow abstraction allows you to focus on your algorithmic problems without re-implementing the whole FL process.
+
+Annotation of the table: 
+
+_Server stages_: **Sel** -- Selection, **Com** -- Compression, **Agg** -- Aggregation
+
+_Client stages_: **Train**, **Com** -- Compression, **Enc** -- Encryption
+
+| Venue | Title | Sel | Com | Agg | Train | Com | Enc |
+| :--- | :---: | :---: | :---: | :---: | :---: | :---: | ---: |
+| INFOCOM'20 | Optimizing Federated Learning on Non-IID Data with Reinforcement Learning | ✓ | | | | | | |
+| OSDI'21 | Oort: Informed Participant Selection for Scalable Federated Learning | ✓ | | | | | | |
+| HPDC'20 | TiFL: A Tier-based Federated Learning System | ✓ | | | | | | |
+| IoT'21 | FedMCCS: Multicriteria Client Selection Model for Optimal IoT Federated Learning | ✓ | | | | | | |
+| KDD'20 | FedFast: Going Beyond Average for Faster Training of Federated Recommender Systems | ✓ | | ✓ | | | | |
+| TNNLS 2019 | Robust and Communication-Efficient Federated Learning From Non-i.i.d. Data | | ✓ | | | | ✓ | |
+| NIPS'20 | Ensemble Distillation for Robust Model Fusion in Federated Learning | | | ✓ | | | | |
+| ICDCS 2019 | CMFL: Mitigating Communication Overhead for Federated Learning | | | | | | ✓ | |
+| ICML'20 | FetchSGD: Communication-Efficient Federated Learning with Sketching | | | ✓ | | | ✓ | |
+| ICML'20 | SCAFFOLD: Stochastic Controlled Averaging for Federated Learning | | | ✓ | | | ✓ | |
+| TPDS'20 | FedSCR: Structure-Based Communication Reduction for Federated Learning | | | ✓ | | | ✓ | |
+| HotEdge 2018 | eSGD: Communication Efficient Distributed Deep Learning on the Edge | | | | | | ✓ | |
+| ICML'20 | Adaptive Federated Optimization | | | | | ✓ | | |
+| CVPR'21 | Privacy-preserving Collaborative Learning with Automatic Transformation Search | | | | | ✓ | | |
+| MLSys'20 | Federated Optimization in Heterogeneous Networks | | | | | ✓ | | |
+| ICLR'20 | Federated Learning with Matched Averaging | | | ✓ | | ✓ | | |
+| ACMMM'20 | Performance Optimization for Federated Person Re-identification via Benchmark Analysis | | | ✓ | | ✓ | | |
+| NIPS'20 | Distributionally Robust Federated Averaging | | | ✓ | | ✓ | | |
+| NIPS'20 | Group Knowledge Transfer: Federated Learning of Large CNNs at the Edge | | | ✓ | | ✓ | | |
+| NIPS'20 | Personalized Federated Learning with Moreau Envelopes | | | ✓ | | ✓ | | |
+| ICLR'20 | Fair Resource Allocation in Federated Learning | | | ✓ | | ✓ | | |
+| ICML'20 | Federated Learning with Only Positive Labels | | | ✓ | | ✓ | | |
+| AAAI'21 | Addressing Class Imbalance in Federated Learning | | | ✓ | | ✓ | | |
+| AAAI'21 | Federated Block Coordinate Descent Scheme for Learning Global and Personalized Models | | | ✓ | | ✓ | | |
+| IoT'20 | Toward Communication-Efficient Federated Learning in the Internet of Things With Edge Computing | | | ✓ | | ✓ | ✓ | |
+| ICML'20 | Acceleration for Compressed Gradient Descent in Distributed and Federated Optimization | | | ✓ | | ✓ | ✓ | |
+| INFOCOM 2018 | When Edge Meets Learning: Adaptive Control for Resource-Constrained Distributed Machine Learning | | | ✓ | | ✓ | | |
+| ATC'20 | BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning | | | ✓ | | | | ✓ |
+| AAAI'21 | FLAME: Differentially Private Federated Learning in the Shuffle Model | | | ✓ | | | | ✓ |
+| TIFS'20 | Federated Learning with Differential Privacy: Algorithms and Performance Analysis | | | ✓ | | | | ✓ |
+| GLOBECOM'20 | Towards Efficient Secure Aggregation for Model Update in Federated Learning | | | ✓ | | | | ✓ |
+| MobiCom'20 | Billion-Scale Federated Learning on Mobile Clients: A Submodel Design with Tunable Privacy | | | ✓ | | ✓ | | ✓ |
+| IoT'20 | Privacy-Preserving Federated Learning in Fog Computing | | | ✓ | | ✓ | | ✓ |

+ 232 - 0
docs/en/tutorials/dataset.md

@@ -0,0 +1,232 @@
+# Tutorial 3: Datasets
+
+In this note, we present how to use the out-of-the-box datasets to simulate different federated learning (FL) scenarios.
+Besides, we introduce how to use customized datasets in EasyFL.
+
+We currently provide four out-of-the-box datasets: FEMNIST, Shakespeare, CIFAR-10, and CIFAR-100. FEMNIST and
+Shakespeare are adopted from [LEAF benchmark](https://leaf.cmu.edu/). We plan to integrate and provide more
+out-of-the-box datasets in the future.
+
+## Out-of-the-box Datasets
+
+The simulation of different FL scenarios is configured in the configurations. You can refer to the
+other [tutorial](config.md) to learn more about how to modify configs. In this note, we focus on how to config the
+datasets with different simulations.
+
+The following are dataset configurations.
+
+```yaml
+data:
+  # The root directory where datasets are stored.
+  root: "./data/"
+  # The name of the dataset, support: femnist, shakespeare, cifar10, and cifar100.
+  dataset: femnist
+  # The data distribution of each client, support: iid, niid (for femnist and shakespeare), and dir and class (for cifar datasets).
+  # `iid` means independent and identically distributed data.
+  # `niid` means non-independent and identically distributed data for FEMNIST and Shakespeare.
+  # `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+  # `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+  split_type: "iid"
+
+  # The minimal number of samples in each client. It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+  min_size: 10
+  # The fraction of data sampled for LEAF datasets. e.g., 10% means that only 10% of the total dataset size is used.
+  data_amount: 0.05
+  # The fraction of the number of clients used when the split_type is 'iid'.
+  iid_fraction: 0.1
+  # Whether to partition users of the dataset into train-test groups. Only applicable to femnist and shakespeare datasets.
+  # True means partitioning the users of the dataset into train-test groups.
+  # False means partitioning each user's samples into train-test groups.
+  user: False
+  # The fraction of data for training; the rest are for testing.
+  train_test_split: 0.9
+
+  # The number of classes in each client. Only applicable when the split_type is 'class'.  
+  class_per_client: 1
+  # The targeted number of clients to simulate. For non-LEAF datasets, it is the number of clients that the data is split into; for LEAF datasets, it is only used when `split_type` is `class`.
+  num_of_clients: 100
+  # The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir` for CIFAR datasets.
+  alpha: 0.5
+
+  # The targeted distribution of quantities to simulate data quantity heterogeneity.
+  # The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+  # The `num_of_clients` should be divisible by `len(weights)`.
+  # None means clients are simulated with the same data quantity.
+  weights: NULL
+```
+
+Among them, `root` is applicable to all datasets. It specifies the directory to store datasets.
+
+EasyFL automatically downloads a dataset if it does not exist in the root directory.
+
+Next, we introduce the simulation and configuration for specific datasets.
+
+### FEMNIST and Shakespeare Datasets
+
+The following are basic stats of these two datasets.
+
+FEMNIST
+
+* Overview: Image Dataset
+* Details: 3500 users, 62 different classes (10 digits, 26 lowercase, 26 uppercase), images are 28 by 28 pixels (with
+  option to make them all 128 by 128 pixels)
+* Task: Image Classification
+
+Shakespeare
+
+* Overview: Text Dataset of Shakespeare Dialogues
+* Details: 1129 users (reduced to 660 with our choice of sequence length)
+* Task: Next-Character Prediction
+
+The datasets are non-IID in nature, i.e., the data are not independent and identically distributed across users.
+
+`split_type`: There are two options for these two datasets: `iid` and `niid`, representing IID data simulation and
+non-IID data simulation.
+
+Five hyper-parameters determine the simulated dataset: `min_size`, `data_amount`, `iid_fraction`, `train_test_split`,
+and `user`.
+
+`user` is a boolean that determines whether to split the dataset into train and test sets by user or by sample.
+`user: True` means partitioning the users of the dataset into train-test groups, i.e., some users are used for
+training and the others for testing.
+`user: False` means partitioning each user's samples into train-test groups, i.e., the data of each client is split
+into a training set and a testing set.
+
+Note: we normally use `test_mode: test_in_client` for these two datasets.
+
+#### IID Simulation
+
+In IID simulation, data are randomly partitioned into multiple clients.
+
+The number of clients is determined by `data_amount` and `iid_fraction`.
+
+#### Non-IID Simulation
+
+Since FEMNIST and Shakespeare are non-IID in nature, each user of the dataset is regarded as a client.
+
+`data_amount` determines the number of clients that participate in training.
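+
+For example, a configuration sketch (with illustrative values) that simulates non-IID FEMNIST on 5% of the total data and tests inside the clients could look like this:
+
+```yaml
+data:
+  dataset: femnist
+  split_type: "niid"
+  data_amount: 0.05
+  min_size: 10
+  train_test_split: 0.9
+  user: False
+test_mode: "test_in_client"
+```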
+
+### CIFAR-10 and CIFAR-100 Datasets
+
+> The **CIFAR-10** dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
+
+> The **CIFAR-100** dataset consists of 60000 32x32 colour images in 100 classes, with 600 images per class. There are 50000 training images and 10000 test images.
+
+`split_type`: There are three options for CIFAR datasets: `iid`, `dir`, and `class`.
+
+Three hyper-parameters determine the simulated dataset: `num_of_clients`, `class_per_client`, and `alpha`.
+
+#### IID Simulation
+
+In IID simulation, the training images of the datasets are randomly partitioned into `num_of_clients` clients.
+
+#### Non-IID Simulation
+
+We can simulate non-IID CIFAR datasets by Dirichlet process (`dir`) or by label class (`class`).
+
+`alpha` controls the level of heterogeneity for `dir` simulation.
+
+`class_per_client` determines the number of classes in each client.
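+
+For example, the snippet below (with illustrative values) simulates 100 clients on CIFAR-10 with Dirichlet-based partitioning; the `class`-based alternative would instead set `split_type: "class"` together with `class_per_client`:
+
+```yaml
+data:
+  dataset: cifar10
+  split_type: "dir"
+  num_of_clients: 100
+  alpha: 0.5
+  min_size: 10
+```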
+
+## Customize Datasets
+
+EasyFL also supports integrating customized datasets to simulate federated learning.
+
+You can use the following classes to integrate customized dataset: [FederatedImageDataset](../api.html#easyfl.datasets.FederatedImageDataset), [FederatedTensorDataset](../api.html#easyfl.datasets.FederatedTensorDataset), and [FederatedTorchDataset](../api.html#easyfl.datasets.FederatedTorchDataset).
+
+The following is an example that integrates [nine person re-identification datasets](https://arxiv.org/abs/2008.11560), where each client contains one dataset.
+
+```python
+import easyfl
+import os
+from torchvision import transforms
+from easyfl.datasets import FederatedImageDataset
+
+TRANSFORM_TRAIN_LIST = transforms.Compose([
+    transforms.Resize((256, 128), interpolation=3),
+    transforms.Pad(10),
+    transforms.RandomCrop((256, 128)),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+TRANSFORM_VAL_LIST = transforms.Compose([
+    transforms.Resize(size=(256, 128), interpolation=3),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+DATASETS = ["MSMT17", "Duke", "Market", "cuhk03", "prid", "cuhk01", "viper", "3dpes", "ilids"]
+
+# Prepare customized training data
+def prepare_train_data(data_dir):
+    client_ids = []
+    roots = []
+    for db in DATASETS:
+        client_ids.append(db)
+        data_path = os.path.join(data_dir, db, "pytorch")
+        roots.append(os.path.join(data_path, "train_all"))
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_TRAIN_LIST,
+                                 client_ids=client_ids)
+    return data
+
+
+# Prepare customized testing data
+def prepare_test_data(data_dir):
+    roots = []
+    client_ids = []
+    for db in DATASETS:
+        test_gallery = os.path.join(data_dir, db, 'pytorch', 'gallery')
+        test_query = os.path.join(data_dir, db, 'pytorch', 'query')
+        roots.extend([test_gallery, test_query])
+        client_ids.extend([f"{db}_gallery", f"{db}_query"])
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_VAL_LIST,
+                                 client_ids=client_ids)
+    return data
+
+
+if __name__ == '__main__':
+    config = {...}
+    data_dir = "datasets/"
+    train_data, test_data = prepare_train_data(data_dir), prepare_test_data(data_dir)
+    easyfl.register_dataset(train_data, test_data)
+    easyfl.init(config)
+    easyfl.run()
+```
+
+The folder structure of these datasets is as follows:
+```
+|-- MSMT17
+|   |-- pytorch
+|   |  	  |-- gallery
+|   |     |-- query
+|   |     |-- train
+|   |     |-- train_all
+|   |     `-- val
+|-- cuhk01
+|   |-- pytorch
+|   |  	  |-- gallery
+|   |     |-- query
+|   |     |-- train
+|   |     |-- train_all
+| ...
+```
+
+Please [email us](mailto:weiming001@e.ntu.edu.sg) if you want to access these datasets, providing:
+1. A short self-introduction.
+2. The purposes of using these datasets.
+
+*⚠️ Further distribution of the datasets is prohibited.*
+
+### Create Your Own Federated Dataset
+
+In case the provided federated dataset classes are not enough, 
+you can implement your own federated dataset by inheriting and implementing [FederatedDataset](../api.html#easyfl.datasets.FederatedDataset).
+
+You can refer to [FederatedImageDataset](../api.html#easyfl.datasets.FederatedImageDataset), [FederatedTensorDataset](../api.html#easyfl.datasets.FederatedTensorDataset), and [FederatedTorchDataset](../api.html#easyfl.datasets.FederatedTorchDataset) as references for the implementation.

+ 42 - 0
docs/en/tutorials/distributed_training.md

@@ -0,0 +1,42 @@
+# Tutorial 6: Distributed Training
+
+EasyFL enables federated learning (FL) training over multiple GPUs. We define the following variables to further illustrate the idea:
+* K: the number of clients who participated in training each round
+* N: the number of available GPUs
+
+When _K == N_, each selected client is allocated to a GPU to train.
+
+When _K > N_, multiple clients are allocated to each GPU and they execute training sequentially on that GPU.
+
+When _K < N_, you can reduce the number of GPUs used in training.
+
+We make it easy to use distributed training. You just need to modify the configs, without changing the core implementations.
+In particular, you need to set the number of GPUs in `gpu` and specific distributed settings in the `distributed` configs.
+
+The following is an example of distributed training on a GPU cluster managed by _slurm_.
+
+```python
+import easyfl
+from easyfl.distributed import slurm
+
+# Get the distributed settings.
+rank, local_rank, world_size, host_addr = slurm.setup()
+# Set the distributed training settings.
+config = {
+    "gpu": world_size,
+    "distributed": {
+        "rank": rank, 
+        "local_rank": local_rank, 
+        "world_size": world_size, 
+        "init_method": host_addr
+    },
+}
+# Initialize EasyFL.
+easyfl.init(config)
+# Execute training with distributed training.
+easyfl.run()
+```
+
+We will further provide scripts to set up distributed training using `multiprocess`. 
+Pull requests are also welcomed.
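+
+In the meantime, a rough, untested sketch of such a script is shown below; it assumes that each spawned process can initialize EasyFL with its own rank and a TCP `init_method`, mirroring the slurm example above, and the rendezvous address is hypothetical:
+
+```python
+import easyfl
+import torch.multiprocessing as mp
+
+def worker(rank, world_size, init_method):
+    # Each spawned process runs one EasyFL worker with its own rank.
+    config = {
+        "gpu": world_size,
+        "distributed": {
+            "rank": rank,
+            "local_rank": rank,
+            "world_size": world_size,
+            "init_method": init_method,
+        },
+    }
+    easyfl.init(config)
+    easyfl.run()
+
+if __name__ == "__main__":
+    world_size = 2  # number of GPUs/processes to use
+    init_method = "tcp://127.0.0.1:23344"  # hypothetical rendezvous address
+    mp.spawn(worker, args=(world_size, init_method), nprocs=world_size)
+```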
+

+ 126 - 0
docs/en/tutorials/high-level_apis.md

@@ -0,0 +1,126 @@
+# Tutorial 1: High-level APIs
+
+EasyFL provides three types of high-level APIs: **initialization**, **registration**, and **execution**.
+The initialization API initializes EasyFL with configurations. 
+Registration APIs register customized components into the platform. 
+Execution APIs start the federated learning process. 
+These APIs are listed in the table below.
+
+| API Name      | Description | Category 
+| :---        |    :----:   | :--- |
+| init(config) | Initialize EasyFL with configurations | Initialization | 
+| register_dataset(train, test, val) | Register a customized dataset | Registration | 
+| register_model(model) | Register a customized model | Registration | 
+| register_server(server) | Register a customized server | Registration |
+| register_client(client) | Register a customized client | Registration |
+| run() | Start federated learning for standalone and distributed training | Execution |
+| start_server() | Start server service for remote training | Execution |
+| start_client() | Start client service for remote training | Execution |
+
+
+`init(config):` Initialize EasyFL with provided configurations (`config`) or default configurations if not specified.  
+These configurations determine the training hardware and hyperparameters.
+
+`register_<module>:` Register customized modules to the system. 
+EasyFL supports the registration of customized datasets, models, server, and client, replacing the default modules in FL training. In the experimental phase, users can register newly developed algorithms to understand their performance.
+
+`run, start_<server/client>:` The APIs are commands to trigger execution. 
+`run()` starts FL using standalone training or distributed training. 
+`start_server` and `start_client` start the server and client services for remote communication; they take `args` for configurations specific to remote training, such as the endpoint addresses.
+
+Next, we introduce how to use these APIs with examples.
+
+## Standalone Training Example
+
+_**Standalone training**_ means that federated learning (FL) training is run on a single hardware device, such as a personal computer or a single GPU.
+_**Distributed training**_ means conducting FL with multiple GPUs to speed up training.
+Running distributed training is similar to standalone training, except that we need to configure the number of GPUs and the distributed settings. 
+We explain more on distributed training in [another note](distributed_training.md) and focus on standalone training example here.  
+
+To run any federated learning process, we need to first call the initialization API and then use the execution API. Registration is optional.
+
+The simplest way is to run with the default setup. 
+```python
+import easyfl
+# Initialize federated learning with default configurations.
+easyfl.init()
+# Execute federated learning training.
+easyfl.run()
+```
+
+You can run it with specified configurations. 
+```python
+import easyfl
+
+# Customized configuration.
+config = {
+    "data": {"dataset": "cifar10", "num_of_clients": 1000},
+    "server": {"rounds": 5, "clients_per_round": 2, "test_all": False},
+    "client": {"local_epoch": 5},
+    "model": "resnet18",
+    "test_mode": "test_in_server",
+}
+# Initialize federated learning with default configurations.
+easyfl.init(config)
+# Execute federated learning training.
+easyfl.run()
+```
+
+You can also run federated learning with customized dataset, model, server, and client implementations.
+
+Note: `registration` must be done before `initialization`.
+
+```python
+import easyfl
+from easyfl.client import BaseClient
+
+# Inherit BaseClient to implement customized client operations.
+class CustomizedClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
+        super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
+        pass  # more initialization of attributes.
+
+    def train(self, conf, device):
+        pass # Implement customized training method, overwriting the default one.
+
+# Register customized client.
+easyfl.register_client(CustomizedClient)
+# Initialize federated learning with default configurations.
+easyfl.init()
+# Execute federated learning training.
+easyfl.run()
+```
+
+## Remote Training Example
+
+_**Remote training**_ is the scenario where the server and the clients are running on different devices.
+We explain more on remote training in [another note](remote_training.md). 
+Here we provide examples on how to start client and server services using the APIs.
+
+Start remote server.
+```python
+import easyfl
+# Configurations for the remote server.
+conf = {"is_remote": True, "local_port": 22999}
+# Initialize only the configuration.
+easyfl.init(conf, init_all=False)
+# Start remote server service.
+# The remote server waits to be connected with the remote client.
+easyfl.start_server()
+```
+
+Start remote client.
+```python
+import easyfl
+# Configurations for the remote client.
+conf = {"is_remote": True, "local_port": 23000}
+# Initialize only the configuration.
+easyfl.init(conf, init_all=False)
+# Start remote client service.
+# The remote client waits to be connected with the remote server.
+easyfl.start_client()
+```
+
+We expose two additional APIs that wrap starting remote services with customized components.
+They are `start_remote_server` and `start_remote_client`. 
+You can explore more in the API documentation.

+ 11 - 0
docs/en/tutorials/index.rst

@@ -0,0 +1,11 @@
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorial
+
+   high-level_apis.md
+   config.md
+   dataset.md
+   model.md
+   customize_server_and_client.md
+   distributed_training.md
+   remote_training.md

+ 92 - 0
docs/en/tutorials/model.md

@@ -0,0 +1,92 @@
+# Tutorial 4: Models
+
+EasyFL supports numerous models and allows you to customize the model.
+
+## Out-of-the-box Models
+
+To use these models, set the configuration `model: <model_name>`. 
+
+We currently provide `lenet`, `resnet`, `resnet18`, `resnet50`,`vgg9`, and `rnn`.
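+
+For instance, a minimal sketch that selects ResNet-18 through the configuration (all other settings fall back to the defaults):
+
+```python
+import easyfl
+
+# Select an out-of-the-box model by name.
+easyfl.init({"model": "resnet18"})
+easyfl.run()
+```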
+
+## Customized Models
+
+EasyFL allows training with a wide range of models by providing the flexibility to customize models. 
+You can customize and register models in two ways: register as a class and register as an instance.
+Either way, the basic requirement is to **inherit and implement `easyfl.models.BaseModel`**. 
+
+### Register as a Class
+
+In the example below, we implement and conduct FL training with a `CustomizedModel`. 
+
+It is applicable when the model does not require extra arguments to initialize.
+
+```python
+from torch import nn
+import torch.nn.functional as F
+import easyfl
+from easyfl.models import BaseModel
+
+# Define a customized model class.
+class CustomizedModel(BaseModel):
+    def __init__(self):
+        super(CustomizedModel, self).__init__()
+        self.conv1 = nn.Conv2d(3, 32, 224, padding=(2, 2))
+        self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
+        self.fc1 = nn.Linear(64, 128)
+        self.fc2 = nn.Linear(128, 42)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 64)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+
+        return x
+
+# Register the customized model class.
+easyfl.register_model(CustomizedModel)
+# Initialize EasyFL.
+easyfl.init()
+# Execute FL training.
+easyfl.run()
+```
+
+### Register as an Instance
+
+When the model requires arguments for initialization, we can implement and register a model instance. 
+
+```python
+from torch import nn
+import torch.nn.functional as F
+import easyfl
+from easyfl.models import BaseModel
+
+# Define a customized model class.
+class CustomizedModel(BaseModel):
+    def __init__(self, num_class):
+        super(CustomizedModel, self).__init__()
+        self.conv1 = nn.Conv2d(3, 32, 224, padding=(2, 2))
+        self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
+        self.fc1 = nn.Linear(64, 128)
+        self.fc2 = nn.Linear(128, num_class)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 64)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+# Register the customized model instance.
+easyfl.register_model(CustomizedModel(num_class=10))
+# Initialize EasyFL.
+easyfl.init()
+# Execute FL training.
+easyfl.run()
+```

+ 257 - 0
docs/en/tutorials/remote_training.md

@@ -0,0 +1,257 @@
+# Tutorial 7: Remote Training
+
+_**Remote training**_ is the scenario where the server and the clients are running on different devices. Standalone and
+distributed training are mainly for federated learning (FL) simulation experiments. Remote training brings FL from
+experimentation to production.
+
+## Remote Training Example
+
+In remote training, both server and clients are started as gRPC services. Here we provide examples on how to start
+server and client services.
+
+Start remote server.
+
+```python
+import easyfl
+
+# Configurations for the remote server.
+conf = {"is_remote": True, "local_port": 22999}
+# Initialize only the configuration.
+easyfl.init(conf, init_all=False)
+# Start remote server service.
+# The remote server waits to be connected with the remote client.
+easyfl.start_server()
+```
+
+Start remote client 1 with port 23000.
+
+```python
+import easyfl
+
+# Configurations for the remote client.
+conf = {
+    "is_remote": True,
+    "local_port": 23000,
+    "server_addr": "localhost:22999",
+    "index": 0,
+}
+# Initialize only the configuration.
+easyfl.init(conf, init_all=False)
+# Start remote client service.
+# The remote client waits to be connected with the remote server.
+easyfl.start_client()
+```
+
+Start remote client 2 with port 23001.
+
+```python
+import easyfl
+
+# Configurations for the remote client.
+conf = {
+    "is_remote": True,
+    "local_port": 23001,
+    "server_addr": "localhost:22999",
+    "index": 1,
+}
+# Initialize only the configuration.
+easyfl.init(conf, init_all=False)
+# Start remote client service.
+# The remote client waits to be connected with the remote server.
+easyfl.start_client()
+```
+
+The client service connects to the remote server via the specified `server_addr`. 
+The client service uses `index` to decide which data partition (user) of the configured dataset it loads.
+
+To trigger remote training, we can send gRPC requests to trigger the training operation.
+
+```python
+import easyfl
+from easyfl.pb import common_pb2 as common_pb
+from easyfl.pb import server_service_pb2 as server_pb
+from easyfl.protocol import codec
+from easyfl.communication import grpc_wrapper
+from easyfl.registry.vclient import VirtualClient
+
+server_addr = "localhost:22999"
+config = {
+    "data": {"dataset": "femnist"},
+    "model": "lenet",
+    "test_mode": "test_in_client"
+}
+# Initialize configurations.
+easyfl.init(config, init_all=False)
+# Initialize the model, using the configured 'lenet'
+model = easyfl.init_model()
+
+# Construct gRPC request 
+stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, server_addr)
+request = server_pb.RunRequest(model=codec.marshal(model))
+# The request contains clients' addresses for the server to communicate with the clients.
+clients = [VirtualClient("1", "localhost:23000", 0), VirtualClient("2", "localhost:23001", 1)]
+for c in clients:
+    request.clients.append(server_pb.Client(client_id=c.id, index=c.index, address=c.address))
+# Send request to trigger training.
+response = stub.Run(request)
+result = "Success" if response.status.code == common_pb.SC_OK else response
+print(result)
+```
+
+Similarly, we can also stop remote training by sending gRPC requests to the server.
+
+```python
+from easyfl.communication import grpc_wrapper
+from easyfl.pb import common_pb2 as common_pb
+from easyfl.pb import server_service_pb2 as server_pb
+
+server_addr = "localhost:22999"
+stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, server_addr)
+# Send request to stop training.
+response = stub.Stop(server_pb.StopRequest())
+result = "Success" if response.status.code == common_pb.SC_OK else response
+print(result)
+```
+
+## Remote Training on Docker and Kubernetes
+
+EasyFL supports deployment of FL training using Docker and Kubernetes.
+
+Since we cannot easily obtain the server and client addresses in Docker or Kubernetes, especially when scaling up the number of clients,
+EasyFL provides a service discovery mechanism, as shown in the image below.
+![service_discovery](../_static/image/registry.png)
+
+It contains registors that dynamically register the clients and a registry that stores the registered client addresses for the server to query. 
+The registor obtains the addresses of the clients and registers them to the registry. 
+Since the clients are unaware of the container environment they are running in, 
+they must rely on a third-party service (the registor) to fetch their container addresses and complete registration. 
+EasyFL supports two service discovery methods targeting different deployment scenarios: using Docker and using Kubernetes.
+
+The following are the deployment manual and the steps to conduct training in Kubernetes.
+
+⚠️ Note: these commands were tested before refactoring. They may not work as expected now. **Need further testing**. 
+
+### Deployment using Docker
+
+Important: Adjust the `Memory` constraint of Docker to be > 11 GB (to be optimized).
+
+1. Build Docker images and start services with either Docker Compose or individual Docker containers.
+2. Start training with a gRPC message.
+
+#### Build images
+
+```
+make base_image
+make image
+```
+
+Or
+
+```
+docker build -t easyfl:base -f docker/base.Dockerfile .
+docker build -t easyfl-client -f docker/client.Dockerfile .
+docker build -t easyfl-server -f docker/server.Dockerfile .
+docker build -t easyfl-run -f docker/run.Dockerfile .
+```
+
+#### Start with Docker Compose
+
+Use docker compose to start all services.
+```
+docker-compose up --scale client=2 && docker-compose rm -fsv
+```
+
+Mac users with Docker Desktop > 2.0 may encounter a port conflict (`bind: address already in use`).
+The workaround is to run with 
+```
+docker-compose up && docker-compose rm -fsv
+``` 
+and start another terminal to scale with 
+```
+docker-compose up --scale client=2 && docker-compose rm -fsv
+```
+
+#### Etcd Setup
+
+```
+export NODE1=localhost
+export DATA_DIR="etcd-data"
+REGISTRY=quay.io/coreos/etcd
+
+docker run --rm \
+  -p 23790:2379 \
+  -p 23800:2380 \
+  --volume=${DATA_DIR}:/etcd-data \
+  --name etcd ${REGISTRY}:v3.4.0 \
+  /usr/local/bin/etcd \
+  --data-dir=/etcd-data --name node1 \
+  --initial-advertise-peer-urls http://${NODE1}:2380 --listen-peer-urls http://0.0.0.0:2380 \
+  --advertise-client-urls http://${NODE1}:2379 --listen-client-urls http://0.0.0.0:2379 \
+  --initial-cluster node1=http://${NODE1}:2380
+```
+
+#### Docker Register
+
+```
+docker run --name docker-register --rm -d -e HOST_IP=<172.18.0.1> -e ETCD_HOST=<172.17.0.1>:2379 -v /var/run/docker.sock:/var/run/docker.sock -t wingalong/docker-register
+```
+* HOST_IP: the IP address of the network the client runs on, i.e., the gateway in `docker inspect easyfl-client` 
+* ETCD_HOST: the IP address of etcd, i.e., the gateway in `docker inspect etcd`
+
+#### Start containers
+
+```shell
+# 1. Start clients
+docker run --rm -p 23400:23400 --name client0 --network host -v <dataset_path>/femnist/data:/app/<dataset_path>/femnist/data easyfl-client --index=0 --is-remote=True --local-port=23400 --server-addr="localhost:23501"
+docker run --rm -p 23401:23401 --name client1 --network host -v <dataset_path>/femnist/data:/app/<dataset_path>/femnist/data easyfl-client --index=1 --is-remote=True --local-port=23401 --server-addr="localhost:23501"
+
+# 2. Start server
+docker run --rm -p 23501:23501 --name easyfl-server --network host  easyfl-server --local-port=23501 --is-remote=True
+```
+
+Note: replace `<dataset_path>` with your actual dataset directory.
+
+#### Start Training Remotely 
+```
+docker run --rm --name easyfl-run --network host easyfl-run --server-addr 127.0.0.1:23501 --etcd-addr 127.0.0.1:23790
+```
+It sends a gRPC message to the server to start training.
+
+### Deployment using Kubernetes
+
+
+```shell
+# 1. Deploy tracker
+kubectl apply -f kubernetes/tracker.yml
+
+# 2. Deploy server
+kubectl apply -f kubernetes/server.yml
+
+# 3. Deploy client
+kubectl apply -f kubernetes/client.yml
+
+# 4. Scale client
+kubectl scale -n easyfl deployment easyfl-client --replicas=6
+
+# 5. Check pods
+kubectl get pods -n easyfl -o wide
+
+# 6. Run
+
+python examples/remote_run.py --server-addr localhost:32501 --source kubernetes
+
+# 7. Check logs
+kubectl logs -f -n easyfl easyfl-server
+
+# 8. Get results
+python examples/test_services.py --task-id task_ijhwqg
+
+# 9. Save log
+kubectl logs -n easyfl easyfl-server > server-log.log
+
+# 10. Stop client/server/tracker
+kubectl delete -f kubernetes/client.yml
+kubectl delete -f kubernetes/server.yml
+kubectl delete -f kubernetes/tracker.yml
+```

+ 22 - 0
easyfl/__init__.py

@@ -0,0 +1,22 @@
+from easyfl.coordinator import (
+    init,
+    init_dataset,
+    init_model,
+    start_server,
+    start_client,
+    run,
+    register_dataset,
+    register_model,
+    register_server,
+    register_client,
+    load_config,
+)
+
+from easyfl.service import (
+    start_remote_server,
+    start_remote_client,
+)
+
+__all__ = ["init", "init_dataset", "init_model", "start_server", "start_client", "run",
+           "register_dataset", "register_model", "register_server", "register_client",
+           "load_config", "start_remote_server", "start_remote_client"]

+ 5 - 0
easyfl/client/__init__.py

@@ -0,0 +1,5 @@
+from easyfl.client.base import BaseClient
+from easyfl.client.service import ClientService
+
+__all__ = ['BaseClient', 'ClientService']
+

+ 471 - 0
easyfl/client/base.py

@@ -0,0 +1,471 @@
+import argparse
+import copy
+import logging
+import time
+
+import torch
+
+from easyfl.client.service import ClientService
+from easyfl.communication import grpc_wrapper
+from easyfl.distributed.distributed import CPU
+from easyfl.pb import common_pb2 as common_pb
+from easyfl.pb import server_service_pb2 as server_pb
+from easyfl.protocol import codec
+from easyfl.tracking import metric
+from easyfl.tracking.client import init_tracking
+from easyfl.tracking.evaluation import model_size
+
+logger = logging.getLogger(__name__)
+
+
+def create_argument_parser():
+    """Create argument parser with arguments/configurations for starting remote client service.
+
+    Returns:
+        argparse.ArgumentParser: Parser with client service arguments.
+    """
+    parser = argparse.ArgumentParser(description='Federated Client')
+    parser.add_argument('--local-port',
+                        type=int,
+                        default=23000,
+                        help='Listen port of the client')
+    parser.add_argument('--server-addr',
+                        type=str,
+                        default="localhost:22999",
+                        help='Address of server in [IP]:[PORT] format')
+    parser.add_argument('--tracker-addr',
+                        type=str,
+                        default="localhost:12666",
+                        help='Address of tracking service in [IP]:[PORT] format')
+    parser.add_argument('--is-remote',
+                        type=bool,
+                        default=False,
+                        help='Whether start as a remote client.')
+    return parser
+
+
+class BaseClient(object):
+    """Default implementation of federated learning client.
+
+    Args:
+        cid (str): Client id.
+        conf (omegaconf.dictconfig.DictConfig): Client configurations.
+        train_data (:obj:`FederatedDataset`): Training dataset.
+        test_data (:obj:`FederatedDataset`): Test dataset.
+        device (str): Hardware device for training, cpu or cuda devices.
+        sleep_time (float): Duration to sleep after training, to simulate stragglers.
+        is_remote (bool): Whether to start remote training.
+        local_port (int): Port of remote client service.
+        server_addr (str): Remote server service grpc address.
+        tracker_addr (str): Remote tracking service grpc address.
+
+
+    Override the class and functions to implement customized client.
+
+    Example:
+        >>> from easyfl.client import BaseClient
+        >>> class CustomizedClient(BaseClient):
+        >>>     def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
+        >>>         super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
+        >>>         pass  # more initialization of attributes.
+        >>>
+        >>>     def train(self, conf, device=CPU):
+        >>>         # Implement customized client training method, which overwrites the default training method.
+        >>>         pass
+    """
+    def __init__(self,
+                 cid,
+                 conf,
+                 train_data,
+                 test_data,
+                 device,
+                 sleep_time=0,
+                 is_remote=False,
+                 local_port=23000,
+                 server_addr="localhost:22999",
+                 tracker_addr="localhost:12666"):
+        self.cid = cid
+        self.conf = conf
+        self.train_data = train_data
+        self.train_loader = None
+        self.test_data = test_data
+        self.test_loader = None
+        self.device = device
+
+        self.round_time = 0
+        self.train_time = 0
+        self.test_time = 0
+
+        self.train_accuracy = []
+        self.train_loss = []
+        self.test_accuracy = 0
+        self.test_loss = 0
+
+        self.profiled = False
+        self._sleep_time = sleep_time
+
+        self.compressed_model = None
+        self.model = None
+        self._upload_holder = server_pb.UploadContent()
+
+        self.is_remote = is_remote
+        self.local_port = local_port
+        self._server_addr = server_addr
+        self._tracker_addr = tracker_addr
+        self._server_stub = None
+        self._tracker = None
+        self._is_train = True
+
+        if conf.track:
+            self._tracker = init_tracking(init_store=False)
+
+    def run_train(self, model, conf):
+        """Conduct training on clients.
+
+        Args:
+            model (nn.Module): Model to train.
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+        Returns:
+            :obj:`UploadRequest`: Training contents. Unify the interface for both local and remote operations.
+        """
+        self.conf = conf
+        if conf.track:
+            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid)
+
+        self._is_train = True
+
+        self.download(model)
+        self.track(metric.TRAIN_DOWNLOAD_SIZE, model_size(model))
+
+        self.decompression()
+
+        self.pre_train()
+        self.train(conf, self.device)
+        self.post_train()
+
+        self.track(metric.TRAIN_ACCURACY, self.train_accuracy)
+        self.track(metric.TRAIN_LOSS, self.train_loss)
+        self.track(metric.TRAIN_TIME, self.train_time)
+
+        if conf.local_test:
+            self.test_local()
+
+        self.compression()
+
+        self.track(metric.TRAIN_UPLOAD_SIZE, model_size(self.compressed_model))
+
+        self.encryption()
+
+        return self.upload()
+
+    def run_test(self, model, conf):
+        """Conduct testing on clients.
+
+        Args:
+            model (nn.Module): Model to test.
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+        Returns:
+            :obj:`UploadRequest`: Testing contents. Unify the interface for both local and remote operations.
+        """
+        self.conf = conf
+        if conf.track:
+            reset = not self._is_train
+            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid, reset_client=reset)
+
+        self._is_train = False
+
+        self.download(model)
+        self.track(metric.TEST_DOWNLOAD_SIZE, model_size(model))
+
+        self.decompression()
+
+        self.pre_test()
+        self.test(conf, self.device)
+        self.post_test()
+
+        self.track(metric.TEST_ACCURACY, float(self.test_accuracy))
+        self.track(metric.TEST_LOSS, float(self.test_loss))
+        self.track(metric.TEST_TIME, self.test_time)
+
+        return self.upload()
+
+    def download(self, model):
+        """Download model from the server.
+
+        Args:
+            model (nn.Module): Global model distributed from the server.
+        """
+        if self.compressed_model:
+            self.compressed_model.load_state_dict(model.state_dict())
+        else:
+            self.compressed_model = copy.deepcopy(model)
+
+    def decompression(self):
+        """Decompressed model. It can be further implemented when the model is compressed in the server."""
+        self.model = self.compressed_model
+
+    def pre_train(self):
+        """Preprocessing before training."""
+        pass
+
+    def train(self, conf, device=CPU):
+        """Execute client training.
+
+        Args:
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+            device (str): Hardware device for training, cpu or cuda devices.
+        """
+        start_time = time.time()
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        self.train_loss = []
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for batched_x, batched_y in self.train_loader:
+                x, y = batched_x.to(device), batched_y.to(device)
+                optimizer.zero_grad()
+                out = self.model(x)
+                loss = loss_fn(out, y)
+                loss.backward()
+                optimizer.step()
+                batch_loss.append(loss.item())
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            self.train_loss.append(float(current_epoch_loss))
+            logger.debug("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
+        self.train_time = time.time() - start_time
+        logger.debug("Client {}, Train Time: {}".format(self.cid, self.train_time))
+
+    def post_train(self):
+        """Postprocessing after training."""
+        pass
+
+    def pretrain_setup(self, conf, device):
+        """Setup loss function and optimizer before training."""
+        self.simulate_straggler()
+        self.model.train()
+        self.model.to(device)
+        loss_fn = self.load_loss_fn(conf)
+        optimizer = self.load_optimizer(conf)
+        if self.train_loader is None:
+            self.train_loader = self.load_loader(conf)
+        return loss_fn, optimizer
+
+    def load_loss_fn(self, conf):
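+        """Load the loss function. Defaults to cross-entropy loss."""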
+        return torch.nn.CrossEntropyLoss()
+
+    def load_optimizer(self, conf):
+        """Load training optimizer. Implemented Adam and SGD."""
+        if conf.optimizer.type == "Adam":
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=conf.optimizer.lr)
+        else:
+            # default using optimizer SGD
+            optimizer = torch.optim.SGD(self.model.parameters(),
+                                        lr=conf.optimizer.lr,
+                                        momentum=conf.optimizer.momentum,
+                                        weight_decay=conf.optimizer.weight_decay)
+        return optimizer
+
+    def load_loader(self, conf):
+        """Load the training data loader.
+
+        Args:
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+        Returns:
+            torch.utils.data.DataLoader: Data loader.
+        """
+        return self.train_data.loader(conf.batch_size, self.cid, shuffle=True, seed=conf.seed)
+
+    def test_local(self):
+        """Test client local model after training."""
+        pass
+
+    def pre_test(self):
+        """Preprocessing before testing."""
+        pass
+
+    def test(self, conf, device=CPU):
+        """Execute client testing.
+
+        Args:
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+            device (str): Hardware device for training, cpu or cuda devices.
+        """
+        begin_test_time = time.time()
+        self.model.eval()
+        self.model.to(device)
+        loss_fn = self.load_loss_fn(conf)
+        if self.test_loader is None:
+            self.test_loader = self.test_data.loader(conf.test_batch_size, self.cid, shuffle=False, seed=conf.seed)
+        # TODO: make evaluation metrics a separate package and apply it here.
+        self.test_loss = 0
+        correct = 0
+        with torch.no_grad():
+            for batched_x, batched_y in self.test_loader:
+                x = batched_x.to(device)
+                y = batched_y.to(device)
+                log_probs = self.model(x)
+                loss = loss_fn(log_probs, y)
+                _, y_pred = torch.max(log_probs, -1)
+                correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
+                self.test_loss += loss.item()
+            test_size = self.test_data.size(self.cid)
+            self.test_loss /= test_size
+            self.test_accuracy = 100.0 * float(correct) / test_size
+
+        logger.debug('Client {}, testing -- Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
+            self.cid, self.test_loss, correct, test_size, self.test_accuracy))
+
+        self.test_time = time.time() - begin_test_time
+        self.model = self.model.cpu()
+
+    def post_test(self):
+        """Postprocessing after testing."""
+        pass
+
+    def encryption(self):
+        """Encrypt the client local model."""
+        # TODO: encryption of model, remember to track encrypted model instead of compressed one after implementation.
+        pass
+
+    def compression(self):
+        """Compress the client local model after training and before uploading to the server."""
+        self.compressed_model = self.model
+
+    def upload(self):
+        """Upload the messages from client to the server.
+
+        Returns:
+            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
+                Only applicable for local training as remote training upload through a gRPC request.
+        """
+        request = self.construct_upload_request()
+        if not self.is_remote:
+            self.post_upload()
+            return request
+
+        self.upload_remotely(request)
+        self.post_upload()
+
+    def post_upload(self):
+        """Postprocessing after uploading training/testing results."""
+        pass
+
+    def construct_upload_request(self):
+        """Construct client upload request for training updates and testing results.
+
+        Returns:
+            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
+        """
+        data = codec.marshal(server_pb.Performance(accuracy=self.test_accuracy, loss=self.test_loss))
+        typ = common_pb.DATA_TYPE_PERFORMANCE
+        try:
+            if self._is_train:
+                data = codec.marshal(copy.deepcopy(self.compressed_model))
+                typ = common_pb.DATA_TYPE_PARAMS
+                data_size = self.train_data.size(self.cid)
+            else:
+                data_size = 1 if not self.test_data else self.test_data.size(self.cid)
+        except KeyError:
+            # When the data size cannot be obtained from the dataset, default to equal-weight aggregation.
+            data_size = 1
+
+        m = self._tracker.get_client_metric().to_proto() if self._tracker else common_pb.ClientMetric()
+        return server_pb.UploadRequest(
+            task_id=self.conf.task_id,
+            round_id=self.conf.round_id,
+            client_id=self.cid,
+            content=server_pb.UploadContent(
+                data=data,
+                type=typ,
+                data_size=data_size,
+                metric=m,
+            ),
+        )
+
+    def upload_remotely(self, request):
+        """Send upload request to remote server via gRPC.
+
+        Args:
+            request (:obj:`UploadRequest`): Upload request.
+        """
+        start_time = time.time()
+
+        self.connect_to_server()
+        resp = self._server_stub.Upload(request)
+
+        upload_time = time.time() - start_time
+        m = metric.TRAIN_UPLOAD_TIME if self._is_train else metric.TEST_UPLOAD_TIME
+        self.track(m, upload_time)
+
+        logger.info("client upload time: {}s".format(upload_time))
+        if resp.status.code == common_pb.SC_OK:
+            logger.info("Uploaded remotely to the server successfully\n")
+        else:
+            logger.error("Failed to upload, code: {}, message: {}\n".format(resp.status.code, resp.status.message))
+
+    # Functions for remote services.
+
+    def start_service(self):
+        """Start client service."""
+        if self.is_remote:
+            grpc_wrapper.start_service(grpc_wrapper.TYPE_CLIENT, ClientService(self), self.local_port)
+
+    def connect_to_server(self):
+        """Establish connection between the client and the server."""
+        if self.is_remote and self._server_stub is None:
+            self._server_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, self._server_addr)
+            logger.info("Successfully connected to gRPC server {}".format(self._server_addr))
+
+    def operate(self, model, conf, index, is_train=True):
+        """A wrapper over operations (training/testing) on clients.
+
+        Args:
+            model (nn.Module): Model for operations.
+            conf (omegaconf.dictconfig.DictConfig): Client configurations.
+            index (int): Client index in the client list, for retrieving data. TODO: improvement.
+            is_train (bool): The flag to indicate whether the operation is training, otherwise testing.
+        """
+        try:
+            # Load the data index depending on server request
+            self.cid = self.train_data.users[index]
+        except IndexError:
+            logger.error("Data index exceed the available data, abort training")
+            return
+
+        if self.conf.track and self._tracker is None:
+            self._tracker = init_tracking(init_store=False)
+
+        if is_train:
+            logger.info("Train on data index {}, client: {}".format(index, self.cid))
+            self.run_train(model, conf)
+        else:
+            logger.info("Test on data index {}, client: {}".format(index, self.cid))
+            self.run_test(model, conf)
+
+    # Functions for tracking.
+
+    def track(self, metric_name, value):
+        """Track a metric.
+
+        Args:
+            metric_name (str): The name of the metric.
+            value (str|int|float|bool|dict|list): The value of the metric.
+        """
+        if not self.conf.track or self._tracker is None:
+            logger.debug("Tracker not available, Tracking not supported")
+            return
+        self._tracker.track_client(metric_name, value)
+
+    def save_metrics(self):
+        """Save client metrics to database."""
+        # TODO: not tested
+        if self._tracker is None:
+            logger.debug("Tracker not available, no saving")
+            return
+        self._tracker.save_client()
+
+    # Functions for simulation.
+
+    def simulate_straggler(self):
+        """Simulate straggler effect of system heterogeneity."""
+        if self._sleep_time > 0:
+            time.sleep(self._sleep_time)
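
A minimal, hypothetical sketch of extending the hooks above: a custom client that only overrides post_upload, which could then be plugged in through the register_client helper defined later in easyfl/coordinator.py. The class and the registration call are illustrative, not part of this commit.

import logging

from easyfl.client.base import BaseClient

logger = logging.getLogger(__name__)


class VerboseClient(BaseClient):
    """Illustrative client that logs after every upload; not part of the library."""

    def post_upload(self):
        # Runs right after upload(); the base implementation above is a no-op.
        mode = "training" if self._is_train else "testing"
        logger.info("Client %s uploaded %s results for round %s", self.cid, mode, self.conf.round_id)

# Registration sketch, assuming easyfl/__init__.py re-exports coordinator.register_client:
# import easyfl
# easyfl.register_client(VerboseClient)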

+ 30 - 0
easyfl/client/service.py

@@ -0,0 +1,30 @@
+import logging
+import threading
+
+from easyfl.pb import client_service_pb2_grpc as client_grpc, client_service_pb2 as client_pb, common_pb2 as common_pb
+from easyfl.protocol import codec
+
+logger = logging.getLogger(__name__)
+
+
+class ClientService(client_grpc.ClientServiceServicer):
+    """"Remote gRPC client service.
+
+    Args:
+        client (:obj:`BaseClient`): Federated learning client instance.
+    """
+    def __init__(self, client):
+        self._base = client
+
+    def Operate(self, request, context):
+        """Perform training/testing operations."""
+        # TODO: add request validation.
+        model = codec.unmarshal(request.model)
+        is_train = request.type == client_pb.OP_TYPE_TRAIN
+        # Threading is necessary to respond to server quickly
+        t = threading.Thread(target=self._base.operate, args=[model, request.config, request.data_index, is_train])
+        t.start()
+        response = client_pb.OperateResponse(
+            status=common_pb.Status(code=common_pb.SC_OK),
+        )
+        return response

+ 3 - 0
easyfl/communication/__init__.py

@@ -0,0 +1,3 @@
+from easyfl.communication.grpc_wrapper import *
+
+__all__ = ['init_stub', 'start_service']

+ 77 - 0
easyfl/communication/grpc_wrapper.py

@@ -0,0 +1,77 @@
+from concurrent import futures
+
+import grpc
+
+from easyfl.pb import client_service_pb2_grpc as client_grpc
+from easyfl.pb import server_service_pb2_grpc as server_grpc
+from easyfl.pb import tracking_service_pb2_grpc as tracking_grpc
+
+MAX_MESSAGE_LENGTH = 524288000  # 500MB
+
+TYPE_CLIENT = "client"
+TYPE_SERVER = "server"
+TYPE_TRACKING = "tracking"
+
+
+def init_stub(typ, address):
+    """Initialize gRPC stub.
+
+    Args:
+        typ (str): Type of service, options: client, server, tracking.
+        address (str): Address of the gRPC service.
+    Returns:
+        (:obj:`ClientServiceStub`|:obj:`ServerServiceStub`|:obj:`TrackingServiceStub`): stub of the gRPC service.
+    """
+
+    channel = grpc.insecure_channel(
+        address,
+        options=[
+            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
+            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
+        ],
+    )
+    if typ == TYPE_CLIENT:
+        stub = client_grpc.ClientServiceStub(channel)
+    elif typ == TYPE_TRACKING:
+        stub = tracking_grpc.TrackingServiceStub(channel)
+    else:
+        stub = server_grpc.ServerServiceStub(channel)
+
+    return stub
+
+
+def start_service(typ, service, port):
+    """Start gRPC service.
+    Args:
+        typ (str): Type of service, option: client, server, tracking.
+        service (:obj:`ClientService`|:obj:`ServerService`|:obj:`TrackingService`): gRPC service to start.
+        port (int): The port of the service.
+    """
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=10),
+        options=[
+            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
+            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
+        ],
+    )
+    if typ == TYPE_CLIENT:
+        client_grpc.add_ClientServiceServicer_to_server(service, server)
+    elif typ == TYPE_TRACKING:
+        tracking_grpc.add_TrackingServiceServicer_to_server(service, server)
+    else:
+        server_grpc.add_ServerServiceServicer_to_server(service, server)
+    server.add_insecure_port('[::]:{}'.format(port))
+    server.start()
+    server.wait_for_termination()
+
+
+def endpoint(host, port):
+    """Format endpoint.
+
+    Args:
+        host (str): Host address.
+        port (int): Port number.
+    Returns:
+        str: Address in `host:port` format.
+    """
+    return "{}:{}".format(host, port)

+ 0 - 0
easyfl/compression/__init__.py


+ 113 - 0
easyfl/config.yaml

@@ -0,0 +1,113 @@
+# The unique identifier for each federated learning task
+task_id: ""
+
+# Provide dataset and federated learning simulation related configuration.
+data:
+  # The root directory where datasets are stored.
+  root: "./data/"
+  # The name of the dataset, support: femnist, shakespeare, cifar10, and cifar100.
+  dataset: femnist
+  # The data distribution of each client, support: iid, niid (for femnist and shakespeare), and dir and class (for cifar datasets).
+    # `iid` means independent and identically distributed data.
+    # `niid` means non-independent and identically distributed data for FEMNIST and Shakespeare.
+    # `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+    # `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+  split_type: "iid"
+
+  # The minimal number of samples in each client. It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+  min_size: 10
+  # The fraction of data sampled for LEAF datasets. e.g., 10% means that only 10% of the total dataset size is used.
+  data_amount: 0.05
+  # The fraction of the number of clients used when the split_type is 'iid'.
+  iid_fraction: 0.1
+  # Whether to partition users of the dataset into train-test groups. Only applicable to the femnist and shakespeare datasets.
+    # True means partitioning users of the dataset into train-test groups.
+    # False means partitioning each user's samples into train-test groups.
+  user: False
+  # The fraction of data for training; the rest are for testing.
+  train_test_split: 0.9
+
+  # The number of classes in each client. Only applicable when the split_type is 'class'.
+  class_per_client: 1
+  # The targeted number of clients to construct. For non-LEAF datasets, it is the number of clients the data is split into; for LEAF datasets, it is only used when split_type is 'class'.
+  num_of_clients: 100
+  # The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir` for CIFAR datasets.
+  alpha: 0.5
+
+  # The targeted distribution of quantities to simulate data quantity heterogeneity.
+    # The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+    # The `num_of_clients` should be divisible by `len(weights)`.
+    # None means clients are simulated with the same data quantity.
+  weights: NULL
+
+# The name of the model for training, support: lenet, rnn, resnet, resnet18, resnet50, vgg9.
+model: lenet
+# How to conduct testing, options: test_in_client or test_in_server.
+  # `test_in_client` means that each client has a test set to run testing.
+  # `test_in_server` means that server has a test set to run testing for the global model. Use this mode for cifar datasets.
+test_mode: "test_in_client"
+# The way to measure testing performance (accuracy) when test mode is `test_in_client`, support: average or weighted (means weighted average).
+test_method: "average"
+
+server:
+  track: False  # Whether to track server metrics using the tracking service.
+  rounds: 10  # The total number of training rounds.
+  clients_per_round: 5  # The number of clients to train in each round.
+  test_every: 1  # The frequency of testing: conduct testing every N rounds.
+  save_model_every: 10  # The frequency of saving model: save the model every N rounds.
+  save_model_path: ""  # The path to save the model. Defaults to the root directory of the library.
+  batch_size: 32  # The batch size for test_in_server.
+  test_all: False  # Whether to test all clients or only the selected clients.
+  random_selection: True  # Whether to select training clients randomly.
+  # The strategy to aggregate client uploaded models, options: FedAvg, equal.
+    # FedAvg aggregates models using weighted average, where the weights are data size of clients.
+    # equal aggregates model by simple averaging.
+  aggregation_stragtegy: "FedAvg"
+  # The content of aggregation, options: all, parameters.
+    # all means aggregating models using state_dict, including both model parameters and persistent buffers like BatchNorm stats.
+    # parameters means aggregating only model parameters.
+  aggregation_content: "all"
+
+client:
+  track: False  # Whether to track client metrics using the tracking service.
+  batch_size: 32  # The batch size for training in the client.
+  test_batch_size: 5  # The batch size for testing in the client.
+  local_epoch: 10  # The number of epochs to train in each round.
+  optimizer:
+    type: "Adam"  # The name of the optimizer, options: Adam, SGD.
+    lr: 0.001
+    momentum: 0.9
+    weight_decay: 0
+  seed: 0
+  local_test: False  # Whether to test the trained models in clients before uploading them to the server.
+
+gpu: 0  # The total number of GPUs used in training. 0 means CPU.
+distributed:  # The distributed training configurations. It is only applicable when gpu > 1.
+  backend: "nccl"  # The distributed backend.
+  init_method: ""
+  world_size: 0
+  rank: 0
+  local_rank: 0
+
+tracking:  # The configurations for logging and tracking.
+  database: ""  # The path of local dataset, sqlite3.
+  log_file: ""
+  log_level: "INFO"  # The level of logging.
+  metric_file: ""
+  save_every: 1
+
+# The configuration for system heterogeneity simulation.
+resource_heterogeneous:
+  simulate: False  # Whether simulate system heterogeneity in federated learning.
+  # The type of heterogeneity to simulate, support: iso, dir, real.
+    # iso means that
+  hetero_type: "real"
+  level: 3  # The level of heterogeneity (0-5); 0 means no heterogeneity among clients.
+  sleep_group_num: 1000  # The number of groups with different sleep times. 1 means all clients are the same.
+  total_time: 1000  # The total sleep time of all clients, unit: second.
+  fraction: 1  # The fraction of clients participating in the heterogeneity simulation.
+  grouping_strategy: "greedy"  # The grouping strategy to handle system heterogeneity, support: random, greedy, slowest.
+  initial_default_time: 5  # The estimated default training time for each training round, unit: second.
+  default_time_momentum: 0.2  # The default momentum for default time update.
+
+seed: 0  # The random seed.
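
These defaults can be selectively overridden: init_conf/load_config in easyfl/coordinator.py (further below in this commit) merge a user-supplied nested dictionary into this file with OmegaConf. A minimal sketch with illustrative values only:

from easyfl.coordinator import init_conf

# Nested keys mirror the structure of config.yaml; unspecified keys keep their defaults.
overrides = {
    "data": {"dataset": "femnist", "split_type": "iid", "num_of_clients": 20},
    "server": {"rounds": 5, "clients_per_round": 4},
    "client": {"local_epoch": 5, "optimizer": {"type": "SGD", "lr": 0.01}},
    "gpu": 0,
}
config = init_conf(overrides)  # returns an omegaconf DictConfig
print(config.server.rounds)    # 5, from the overrides
print(config.data.min_size)    # 10, from the defaults above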

+ 481 - 0
easyfl/coordinator.py

@@ -0,0 +1,481 @@
+import logging
+import os
+import random
+import sys
+import time
+from os import path
+
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from easyfl.client.base import BaseClient
+from easyfl.datasets import TEST_IN_SERVER
+from easyfl.datasets.data import construct_datasets
+from easyfl.distributed import dist_init, get_device
+from easyfl.models.model import load_model
+from easyfl.server.base import BaseServer
+from easyfl.simulation.system_hetero import resource_hetero_simulation
+
+logger = logging.getLogger(__name__)
+
+
+class Coordinator(object):
+    """Coordinator manages federated learning server and client.
+    A single instance of coordinator is initialized for each federated learning task
+    when the package is imported.
+    """
+
+    def __init__(self):
+        self.registered_model = False
+        self.registered_dataset = False
+        self.registered_server = False
+        self.registered_client = False
+        self.train_data = None
+        self.test_data = None
+        self.val_data = None
+        self.conf = None
+        self.model = None
+        self._model_class = None
+        self.server = None
+        self._server_class = None
+        self.clients = None
+        self._client_class = None
+        self.tracker = None
+
+    def init(self, conf, init_all=False):
+        """Initialize coordinator
+
+        Args:
+            conf (omegaconf.dictconfig.DictConfig): Internal configurations for federated learning.
+            init_all (bool): Whether to initialize the dataset, model, server, and clients in addition to the configuration.
+        """
+        self.init_conf(conf)
+
+        _set_random_seed(conf.seed)
+
+        if init_all:
+            self.init_dataset()
+
+            self.init_model()
+
+            self.init_server()
+
+            self.init_clients()
+
+    def run(self):
+        """Run the coordinator and the federated learning process.
+        Initialize `torch.distributed` if distributed training is configured.
+        """
+        start_time = time.time()
+
+        if self.conf.is_distributed:
+            dist_init(
+                self.conf.distributed.backend,
+                self.conf.distributed.init_method,
+                self.conf.distributed.world_size,
+                self.conf.distributed.rank,
+                self.conf.distributed.local_rank,
+            )
+        self.server.start(self.model, self.clients)
+        self.print_("Total training time {:.1f}s".format(time.time() - start_time))
+
+    def init_conf(self, conf):
+        """Initialize coordinator configuration.
+
+        Args:
+            conf (omegaconf.dictconfig.DictConfig): Configurations.
+        """
+        self.conf = conf
+        self.conf.is_distributed = (self.conf.gpu > 1)
+        if self.conf.gpu == 0:
+            self.conf.device = "cpu"
+        elif self.conf.gpu == 1:
+            self.conf.device = 0
+        else:
+            self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size,
+                                          self.conf.distributed.local_rank)
+        self.print_("Configurations: {}".format(self.conf))
+
+    def init_dataset(self):
+        """Initialize datasets. Use provided datasets if not registered."""
+        if self.registered_dataset:
+            return
+        self.train_data, self.test_data = construct_datasets(self.conf.data.root,
+                                                             self.conf.data.dataset,
+                                                             self.conf.data.num_of_clients,
+                                                             self.conf.data.split_type,
+                                                             self.conf.data.min_size,
+                                                             self.conf.data.class_per_client,
+                                                             self.conf.data.data_amount,
+                                                             self.conf.data.iid_fraction,
+                                                             self.conf.data.user,
+                                                             self.conf.data.train_test_split,
+                                                             self.conf.data.weights,
+                                                             self.conf.data.alpha)
+
+        self.print_(f"Total training data amount: {self.train_data.total_size()}")
+        self.print_(f"Total testing data amount: {self.test_data.total_size()}")
+
+    def init_model(self):
+        """Initialize model instance."""
+        if not self.registered_model:
+            self._model_class = load_model(self.conf.model)
+
+        # If model_class is None, the model was registered as an instance and needs no initialization.
+        if self._model_class:
+            self.model = self._model_class()
+
+    def init_server(self):
+        """Initialize a server instance."""
+        if not self.registered_server:
+            self._server_class = BaseServer
+
+        kwargs = {
+            "is_remote": self.conf.is_remote,
+            "local_port": self.conf.local_port
+        }
+
+        if self.conf.test_mode == TEST_IN_SERVER:
+            kwargs["test_data"] = self.test_data
+            if self.val_data:
+                kwargs["val_data"] = self.val_data
+
+        self.server = self._server_class(self.conf, **kwargs)
+
+    def init_clients(self):
+        """Initialize client instances, each represent a federated learning client."""
+        if not self.registered_client:
+            self._client_class = BaseClient
+
+        # Enforce system heterogeneity of clients.
+        sleep_time = [0 for _ in self.train_data.users]
+        if self.conf.resource_heterogeneous.simulate:
+            sleep_time = resource_hetero_simulation(self.conf.resource_heterogeneous.fraction,
+                                                    self.conf.resource_heterogeneous.hetero_type,
+                                                    self.conf.resource_heterogeneous.sleep_group_num,
+                                                    self.conf.resource_heterogeneous.level,
+                                                    self.conf.resource_heterogeneous.total_time,
+                                                    len(self.train_data.users))
+
+        client_test_data = self.test_data
+        if self.conf.test_mode == TEST_IN_SERVER:
+            client_test_data = None
+
+        self.clients = [self._client_class(u,
+                                           self.conf.client,
+                                           self.train_data,
+                                           client_test_data,
+                                           self.conf.device,
+                                           **{"sleep_time": sleep_time[i]})
+                        for i, u in enumerate(self.train_data.users)]
+
+        self.print_("Clients in total: {}".format(len(self.clients)))
+
+    def init_client(self):
+        """Initialize client instance.
+
+        Returns:
+            :obj:`BaseClient`: The initialized client instance.
+        """
+        if not self.registered_client:
+            self._client_class = BaseClient
+
+        # Get a random client if not specified
+        if self.conf.index:
+            user = self.train_data.users[self.conf.index]
+        else:
+            user = random.choice(self.train_data.users)
+
+        return self._client_class(user,
+                                  self.conf.client,
+                                  self.train_data,
+                                  self.test_data,
+                                  self.conf.device,
+                                  is_remote=self.conf.is_remote,
+                                  local_port=self.conf.local_port,
+                                  server_addr=self.conf.server_addr,
+                                  tracker_addr=self.conf.tracker_addr)
+
+    def start_server(self, args):
+        """Start a server service for remote training.
+
+        Server controls the model and testing dataset if configured to test in server.
+
+        Args:
+            args (argparse.Namespace): Configurations passed in as arguments; they are merged with the default configurations.
+        """
+        if args:
+            self.conf = OmegaConf.merge(self.conf, args.__dict__)
+
+        if self.conf.test_mode == TEST_IN_SERVER:
+            self.init_dataset()
+
+        self.init_model()
+
+        self.init_server()
+
+        self.server.start_service()
+
+    def start_client(self, args):
+        """Start a client service for remote training.
+
+        Client controls training datasets.
+
+        Args:
+            args (argparse.Namespace): Configurations passed in as arguments; they are merged with the default configurations.
+        """
+
+        if args:
+            self.conf = OmegaConf.merge(self.conf, args.__dict__)
+
+        self.init_dataset()
+
+        client = self.init_client()
+
+        client.start_service()
+
+    def register_dataset(self, train_data, test_data, val_data=None):
+        """Register datasets.
+
+        Datasets should inherit from :obj:`FederatedDataset`, e.g., :obj:`FederatedTensorDataset`.
+
+        Args:
+            train_data (:obj:`FederatedDataset`): Training dataset.
+            test_data (:obj:`FederatedDataset`): Testing dataset.
+            val_data (:obj:`FederatedDataset`): Validation dataset.
+        """
+        self.registered_dataset = True
+        self.train_data = train_data
+        self.test_data = test_data
+        self.val_data = val_data
+
+    def register_model(self, model):
+        """Register customized model for federated learning.
+
+        Args:
+            model (nn.Module): PyTorch model, both class and instance are acceptable.
+                Use the model class when there are no specific arguments needed to initialize the model.
+        """
+        self.registered_model = True
+        if not isinstance(model, type):
+            self.model = model
+        else:
+            self._model_class = model
+
+    def register_server(self, server):
+        """Register a customized federated learning server.
+
+        Args:
+            server (:obj:`BaseServer`): Customized federated learning server.
+        """
+        self.registered_server = True
+        self._server_class = server
+
+    def register_client(self, client):
+        """Register a customized federated learning client.
+
+        Args:
+            client (:obj:`BaseClient`): Customized federated learning client.
+        """
+        self.registered_client = True
+        self._client_class = client
+
+    def print_(self, content):
+        """Log the content only when the server is primary server.
+
+        Args:
+            content (str): The content to log.
+        """
+        if self._is_primary_server():
+            logger.info(content)
+
+    def _is_primary_server(self):
+        """Check whether current running server is the primary server.
+
+        In standalone or remote training, the server is primary.
+        In distributed training, the server on `rank0` is primary.
+        """
+        return not self.conf.is_distributed or self.conf.distributed.rank == 0
+
+
+def _set_random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+
+# Initialize the global coordinator object
+_global_coord = Coordinator()
+
+
+def init_conf(conf=None):
+    """Initialize configuration for EasyFL. It overrides and supplements default configuration loaded from config.yaml
+    with the provided configurations.
+
+    Args:
+        conf (dict): Configurations.
+
+    Returns:
+        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
+    """
+    here = path.abspath(path.dirname(__file__))
+    config_file = path.join(here, 'config.yaml')
+    return load_config(config_file, conf)
+
+
+def load_config(file, conf=None):
+    """Load and merge configuration from file and input
+
+    Args:
+        file (str): filename of the configuration.
+        conf (dict): Configurations.
+
+    Returns:
+        omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
+    """
+    config = OmegaConf.load(file)
+    if conf is not None:
+        config = OmegaConf.merge(config, conf)
+    return config
+
+
+def init_logger(log_level):
+    """Initialize internal logger of EasyFL.
+
+    Args:
+        log_level (int): Logger level, e.g., logging.INFO, logging.DEBUG
+    """
+    log_formatter = logging.Formatter("%(asctime)s [%(threadName)s] [%(levelname)-5.5s]  %(message)s")
+    root_logger = logging.getLogger()
+
+    log_level = logging.INFO if not log_level else log_level
+    root_logger.setLevel(log_level)
+
+    file_path = os.path.join(os.getcwd(), "logs")
+    if not os.path.exists(file_path):
+        os.makedirs(file_path)
+    file_path = path.join(file_path, "train" + time.strftime(".%m_%d_%H_%M_%S") + ".log")
+    file_handler = logging.FileHandler(file_path)
+    file_handler.setFormatter(log_formatter)
+    root_logger.addHandler(file_handler)
+
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(log_formatter)
+    root_logger.addHandler(console_handler)
+
+
+def init(conf=None, init_all=True):
+    """Initialize EasyFL.
+
+    Args:
+        conf (dict, optional): Configurations.
+        init_all (bool, optional): Whether to initialize the dataset, model, server, and clients in addition to the configuration.
+    """
+    global _global_coord
+
+    config = init_conf(conf)
+
+    init_logger(config.tracking.log_level)
+
+    _set_random_seed(config.seed)
+
+    _global_coord.init(config, init_all)
+
+
+def run():
+    """Run federated learning process."""
+    global _global_coord
+    _global_coord.run()
+
+
+def init_dataset():
+    """Initialize dataset, either using registered dataset or out-of-the-box datasets set in config."""
+    global _global_coord
+    _global_coord.init_dataset()
+
+
+def init_model():
+    """Initialize model, either using registered model or out-of–the-box model set in config.
+
+    Returns:
+        nn.Module: Model used in federated learning.
+    """
+    global _global_coord
+    _global_coord.init_model()
+
+    return _global_coord.model
+
+
+def start_server(args=None):
+    """Start federated learning server service for remote training.
+
+    Args:
+        args (argparse.Namespace): Configurations passed in as arguments.
+    """
+    global _global_coord
+
+    _global_coord.start_server(args)
+
+
+def start_client(args=None):
+    """Start federated learning client service for remote training.
+
+    Args:
+        args (argparse.Namespace): Configurations passed in as arguments.
+    """
+    global _global_coord
+
+    _global_coord.start_client(args)
+
+
+def get_coordinator():
+    """Get the global coordinator instance.
+
+    Returns:
+        :obj:`Coordinator`: global coordinator instance.
+    """
+    return _global_coord
+
+
+def register_dataset(train_data, test_data, val_data=None):
+    """Register datasets for federated learning training.
+
+    Args:
+        train_data (:obj:`FederatedDataset`): Training dataset.
+        test_data (:obj:`FederatedDataset`): Testing dataset.
+        val_data (:obj:`FederatedDataset`): Validation dataset.
+    """
+    global _global_coord
+    _global_coord.register_dataset(train_data, test_data, val_data)
+
+
+def register_model(model):
+    """Register model for federated learning training.
+
+    Args:
+        model (nn.Module): PyTorch model, both class and instance are acceptable.
+    """
+    global _global_coord
+    _global_coord.register_model(model)
+
+
+def register_server(server):
+    """Register federated learning server.
+
+    Args:
+        server (:obj:`BaseServer`): Customized federated learning server.
+    """
+    global _global_coord
+    _global_coord.register_server(server)
+
+
+def register_client(client):
+    """Register federated learning client.
+
+    Args:
+        client (:obj:`BaseClient`): Customized federated learning client.
+    """
+    global _global_coord
+    _global_coord.register_client(client)
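
A hypothetical end-to-end standalone run built from the helpers above, assuming easyfl/__init__.py re-exports register_model, init, and run from this module; the model below is illustrative only, sized for FEMNIST (1x28x28 inputs, 62 classes).

import easyfl
import torch.nn as nn


class SimpleCNN(nn.Module):
    """Toy model used only to illustrate registration; not part of the library."""

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(16, 62)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))


easyfl.register_model(SimpleCNN)  # a model class or an instance are both accepted
easyfl.init({"data": {"dataset": "femnist"}, "gpu": 0})
easyfl.run()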

+ 26 - 0
easyfl/datasets/__init__.py

@@ -0,0 +1,26 @@
+from easyfl.datasets.data import construct_datasets
+from easyfl.datasets.dataset import (
+    FederatedDataset,
+    FederatedImageDataset,
+    FederatedTensorDataset,
+    FederatedTorchDataset,
+    TEST_IN_SERVER,
+    TEST_IN_CLIENT,
+)
+from easyfl.datasets.simulation import (
+    data_simulation,
+    iid,
+    non_iid_dirichlet,
+    non_iid_class,
+    equal_division,
+    quantity_hetero,
+)
+from easyfl.datasets.utils.base_dataset import BaseDataset
+from easyfl.datasets.femnist import Femnist
+from easyfl.datasets.shakespeare import Shakespeare
+from easyfl.datasets.cifar10 import Cifar10
+from easyfl.datasets.cifar100 import Cifar100
+
+__all__ = ['FederatedDataset', 'FederatedImageDataset', 'FederatedTensorDataset', 'FederatedTorchDataset',
+           'construct_datasets', 'data_simulation', 'iid', 'non_iid_dirichlet', 'non_iid_class',
+           'equal_division', 'quantity_hetero', 'BaseDataset', 'Femnist', 'Shakespeare', 'Cifar10', 'Cifar100']

+ 1 - 0
easyfl/datasets/cifar10/__init__.py

@@ -0,0 +1 @@
+from easyfl.datasets.cifar10.cifar10 import Cifar10

+ 88 - 0
easyfl/datasets/cifar10/cifar10.py

@@ -0,0 +1,88 @@
+import logging
+import os
+
+import torchvision
+
+from easyfl.datasets.simulation import data_simulation
+from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR10
+from easyfl.datasets.utils.util import save_dict
+
+logger = logging.getLogger(__name__)
+
+
+class Cifar10(BaseDataset):
+    def __init__(self,
+                 root,
+                 fraction,
+                 split_type,
+                 user,
+                 iid_user_fraction=0.1,
+                 train_test_split=0.9,
+                 minsample=10,
+                 num_class=80,
+                 num_of_client=100,
+                 class_per_client=2,
+                 setting_folder=None,
+                 seed=-1,
+                 weights=None,
+                 alpha=0.5):
+        super(Cifar10, self).__init__(root,
+                                      CIFAR10,
+                                      fraction,
+                                      split_type,
+                                      user,
+                                      iid_user_fraction,
+                                      train_test_split,
+                                      minsample,
+                                      num_class,
+                                      num_of_client,
+                                      class_per_client,
+                                      setting_folder,
+                                      seed)
+        self.train_data, self.test_data = {}, {}
+        self.split_type = split_type
+        self.num_of_client = num_of_client
+        self.weights = weights
+        self.alpha = alpha
+        self.min_size = minsample
+        self.class_per_client = class_per_client
+
+    def download_packaged_dataset_and_extract(self, filename):
+        pass
+
+    def download_raw_file_and_extract(self):
+        train_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=True, download=True)
+        test_set = torchvision.datasets.CIFAR10(root=self.base_folder, train=False, download=True)
+
+        self.train_data = {
+            'x': train_set.data,
+            'y': train_set.targets
+        }
+
+        self.test_data = {
+            'x': test_set.data,
+            'y': test_set.targets
+        }
+
+    def preprocess(self):
+        train_data_path = os.path.join(self.data_folder, "train")
+        test_data_path = os.path.join(self.data_folder, "test")
+        if not os.path.exists(self.data_folder):
+            os.makedirs(self.data_folder)
+        if self.weights is None and os.path.exists(train_data_path):
+            return
+        logger.info("Start CIFAR10 data simulation")
+        _, train_data = data_simulation(self.train_data['x'],
+                                        self.train_data['y'],
+                                        self.num_of_client,
+                                        self.split_type,
+                                        self.weights,
+                                        self.alpha,
+                                        self.min_size,
+                                        self.class_per_client)
+        logger.info("Complete CIFAR10 data simulation")
+        save_dict(train_data, train_data_path)
+        save_dict(self.test_data, test_data_path)
+
+    def convert_data_to_json(self):
+        pass
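
A hypothetical direct use of the class above; it assumes BaseDataset.__init__ derives the base_folder/data_folder attributes used by download_raw_file_and_extract and preprocess from root and setting_folder. The setting_folder name is a placeholder.

from easyfl.datasets.cifar10.cifar10 import Cifar10

dataset = Cifar10(
    root="./data/cifar10",
    fraction=1.0,
    split_type="dir",
    user=False,
    num_of_client=10,
    setting_folder="cifar10_dir_10clients",  # placeholder setting name
    alpha=0.5,
)
dataset.download_raw_file_and_extract()  # fetches CIFAR-10 via torchvision
dataset.preprocess()                     # runs the Dirichlet simulation and saves the train/test splits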

+ 1 - 0
easyfl/datasets/cifar100/__init__.py

@@ -0,0 +1 @@
+from easyfl.datasets.cifar100.cifar100 import Cifar100

+ 88 - 0
easyfl/datasets/cifar100/cifar100.py

@@ -0,0 +1,88 @@
+import logging
+import os
+
+import torchvision
+
+from easyfl.datasets.simulation import data_simulation
+from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR100
+from easyfl.datasets.utils.util import save_dict
+
+logger = logging.getLogger(__name__)
+
+
+class Cifar100(BaseDataset):
+    def __init__(self,
+                 root,
+                 fraction,
+                 split_type,
+                 user,
+                 iid_user_fraction=0.1,
+                 train_test_split=0.9,
+                 minsample=10,
+                 num_class=80,
+                 num_of_client=100,
+                 class_per_client=2,
+                 setting_folder=None,
+                 seed=-1,
+                 weights=None,
+                 alpha=0.5):
+        super(Cifar100, self).__init__(root,
+                                       CIFAR100,
+                                       fraction,
+                                       split_type,
+                                       user,
+                                       iid_user_fraction,
+                                       train_test_split,
+                                       minsample,
+                                       num_class,
+                                       num_of_client,
+                                       class_per_client,
+                                       setting_folder,
+                                       seed)
+        self.train_data, self.test_data = {}, {}
+        self.split_type = split_type
+        self.num_of_client = num_of_client
+        self.weights = weights
+        self.alpha = alpha
+        self.min_size = minsample
+        self.class_per_client = class_per_client
+
+    def download_packaged_dataset_and_extract(self, filename):
+        pass
+
+    def download_raw_file_and_extract(self):
+        train_set = torchvision.datasets.CIFAR100(root=self.base_folder, train=True, download=True)
+        test_set = torchvision.datasets.CIFAR100(root=self.base_folder, train=False, download=True)
+
+        self.train_data = {
+            'x': train_set.data,
+            'y': train_set.targets
+        }
+
+        self.test_data = {
+            'x': test_set.data,
+            'y': test_set.targets
+        }
+
+    def preprocess(self):
+        train_data_path = os.path.join(self.data_folder, "train")
+        test_data_path = os.path.join(self.data_folder, "test")
+        if not os.path.exists(self.data_folder):
+            os.makedirs(self.data_folder)
+        if self.weights is None and os.path.exists(train_data_path):
+            return
+        logger.info("Start CIFAR10 data simulation")
+        _, train_data = data_simulation(self.train_data['x'],
+                                        self.train_data['y'],
+                                        self.num_of_client,
+                                        self.split_type,
+                                        self.weights,
+                                        self.alpha,
+                                        self.min_size,
+                                        self.class_per_client)
+        logger.info("Complete CIFAR10 data simulation")
+        save_dict(train_data, train_data_path)
+        save_dict(self.test_data, test_data_path)
+
+    def convert_data_to_json(self):
+        pass

+ 243 - 0
easyfl/datasets/data.py

@@ -0,0 +1,243 @@
+import importlib
+import json
+import logging
+import os
+
+from easyfl.datasets.dataset import FederatedTensorDataset
+from easyfl.datasets.utils.base_dataset import BaseDataset, CIFAR10, CIFAR100
+from easyfl.datasets.utils.util import load_dict
+
+logger = logging.getLogger(__name__)
+
+
+def read_dir(data_dir):
+    clients = []
+    groups = []
+    data = {}
+
+    files = os.listdir(data_dir)
+    files = [f for f in files if f.endswith('.json')]
+    for f in files:
+        file_path = os.path.join(data_dir, f)
+        with open(file_path, 'r') as inf:
+            cdata = json.load(inf)
+        clients.extend(cdata['users'])
+        if 'hierarchies' in cdata:
+            groups.extend(cdata['hierarchies'])
+        data.update(cdata['user_data'])
+
+    clients = list(sorted(data.keys()))
+    return clients, groups, data
+
+
+def read_data(dataset_name, train_data_dir, test_data_dir):
+    """Load datasets from data directories.
+
+    Args:
+        dataset_name (str): The name of the dataset.
+        train_data_dir (str): The directory of training data.
+        test_data_dir (str): The directory of testing data.
+
+    Returns:
+        list[str]: A list of client ids.
+        list[str]: A list of group ids for dataset with hierarchies.
+        dict: A dictionary of training data, e.g., {"id1": {"x": data, "y": label}, "id2": {"x": data, "y": label}}.
+        dict: A dictionary of testing data. The format is the same as the training data for the FEMNIST and Shakespeare datasets.
+            For CIFAR datasets, the format is {"x": data, "y": label}, for centralized testing in the server.
+    """
+    if dataset_name == CIFAR10 or dataset_name == CIFAR100:
+        train_data = load_dict(train_data_dir)
+        test_data = load_dict(test_data_dir)
+        return [], [], train_data, test_data
+
+    # Data in the directories are `json` files with keys `users` and `user_data`.
+    train_clients, train_groups, train_data = read_dir(train_data_dir)
+    test_clients, test_groups, test_data = read_dir(test_data_dir)
+
+    assert train_clients == test_clients
+    assert train_groups == test_groups
+
+    return train_clients, train_groups, train_data, test_data
+
+
+def load_data(root,
+              dataset_name,
+              num_of_clients,
+              split_type,
+              min_size,
+              class_per_client,
+              data_amount,
+              iid_fraction,
+              user,
+              train_test_split,
+              quantity_weights,
+              alpha):
+    """Simulate and load federated datasets.
+
+    Args:
+        root (str): The root directory where datasets stored.
+        dataset_name (str): The name of the dataset. It currently supports: femnist, shakespeare, cifar10, and cifar100.
+            Among them, femnist and shakespeare are adopted from LEAF benchmark.
+        num_of_clients (int): The targeted number of clients to construct.
+        split_type (str): The type of statistical simulation, options: iid, dir, and class.
+            `iid` means independent and identically distributed data.
+            `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
+            `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+            `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+        min_size (int): The minimal number of samples in each client.
+            It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+        class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
+        data_amount (float): The fraction of data sampled for LEAF datasets.
+            e.g., 10% means that only 10% of the total dataset size is used.
+        iid_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
+        user (bool): A flag to indicate whether to partition users of the dataset into train-test groups.
+            Only applicable to LEAF datasets.
+            True means partitioning users of the dataset into train-test groups.
+            False means partitioning each user's samples into train-test groups.
+        train_test_split (float): The fraction of data for training; the rest are for testing.
+            e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
+        quantity_weights (list[float]): The targeted distribution of quantities to simulate data quantity heterogeneity.
+            The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+            The `num_of_clients` should be divisible by `len(weights)`.
+            None means clients are simulated with the same data quantity.
+        alpha (float): The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir`.
+
+    Returns:
+        dict: A dictionary of training data, e.g., {"id1": {"x": data, "y": label}, "id2": {"x": data, "y": label}}.
+        dict: A dictionary of testing data.
+        function: A function to preprocess the input data (process_x).
+        function: A function to preprocess the labels (process_y).
+        torchvision.transforms.transforms.Compose: Training data transformation.
+        torchvision.transforms.transforms.Compose: Testing data transformation.
+    """
+    user_str = "user" if user else "sample"
+    setting = BaseDataset.get_setting_folder(dataset_name, split_type, num_of_clients, min_size, class_per_client,
+                                             data_amount, iid_fraction, user_str, train_test_split, alpha,
+                                             quantity_weights)
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    dataset_file = os.path.join(dir_path, "data_process", "{}.py".format(dataset_name))
+    if not os.path.exists(dataset_file):
+        logger.error("Please specify a valid process file path for process_x and process_y functions.")
+    dataset_path = "easyfl.datasets.data_process.{}".format(dataset_name)
+    dataset_lib = importlib.import_module(dataset_path)
+    process_x = getattr(dataset_lib, "process_x", None)
+    process_y = getattr(dataset_lib, "process_y", None)
+    transform_train = getattr(dataset_lib, "transform_train", None)
+    transform_test = getattr(dataset_lib, "transform_test", None)
+
+    data_dir = os.path.join(root, dataset_name)
+    if not data_dir:
+        os.makedirs(data_dir)
+    train_data_dir = os.path.join(data_dir, setting, "train")
+    test_data_dir = os.path.join(data_dir, setting, "test")
+
+    if not os.path.exists(train_data_dir) or not os.path.exists(test_data_dir):
+        dataset_class_path = "easyfl.datasets.{}.{}".format(dataset_name, dataset_name)
+        dataset_class_lib = importlib.import_module(dataset_class_path)
+        class_name = dataset_name.capitalize()
+        dataset = getattr(dataset_class_lib, class_name)(root=data_dir,
+                                                         fraction=data_amount,
+                                                         split_type=split_type,
+                                                         user=user,
+                                                         iid_user_fraction=iid_fraction,
+                                                         train_test_split=train_test_split,
+                                                         minsample=min_size,
+                                                         num_of_client=num_of_clients,
+                                                         class_per_client=class_per_client,
+                                                         setting_folder=setting,
+                                                         alpha=alpha,
+                                                         weights=quantity_weights)
+        try:
+            filename = f"{setting}.zip"
+            dataset.download_packaged_dataset_and_extract(filename)
+            logger.info(f"Downloaded packaged dataset {dataset_name}: {filename}")
+        except Exception as e:
+            logger.info(f"Failed to download packaged dataset: {e.args}")
+
+        # CIFAR datasets generate data in the setup() stage; LEAF-related datasets generate data in sampling().
+        if not os.path.exists(train_data_dir):
+            dataset.setup()
+        if not os.path.exists(train_data_dir):
+            dataset.sampling()
+
+    users, train_groups, train_data, test_data = read_data(dataset_name, train_data_dir, test_data_dir)
+    return train_data, test_data, process_x, process_y, transform_train, transform_test
+
+
+def construct_datasets(root,
+                       dataset_name,
+                       num_of_clients,
+                       split_type,
+                       min_size,
+                       class_per_client,
+                       data_amount,
+                       iid_fraction,
+                       user,
+                       train_test_split,
+                       quantity_weights,
+                       alpha):
+    """Construct and load provided federated learning datasets.
+
+    Args:
+        root (str): The root directory where datasets stored.
+        dataset_name (str): The name of the dataset. It currently supports: femnist, shakespeare, cifar10, and cifar100.
+            Among them, femnist and shakespeare are adopted from LEAF benchmark.
+        num_of_clients (int): The targeted number of clients to construct.
+        split_type (str): The type of statistical simulation, options: iid, dir, and class.
+            `iid` means independent and identically distributed data.
+            `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
+            `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+            `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+        min_size (int): The minimal number of samples in each client.
+            It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+        class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
+        data_amount (float): The fraction of data sampled for LEAF datasets.
+            e.g., 10% means that only 10% of the total dataset size is used.
+        iid_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
+        user (bool): A flag to indicate whether to partition users of the dataset into train-test groups.
+            Only applicable to LEAF datasets.
+            True means partitioning users of the dataset into train-test groups.
+            False means partitioning each user's samples into train-test groups.
+        train_test_split (float): The fraction of data for training; the rest are for testing.
+            e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
+        quantity_weights (list[float]): The targeted distribution of quantities to simulate data quantity heterogeneity.
+            The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+            The `num_of_clients` should be divisible by `len(weights)`.
+            None means clients are simulated with the same data quantity.
+        alpha (float): The parameter for Dirichlet distribution simulation, applicable only when split_type is `dir`.
+
+    Returns:
+        :obj:`FederatedDataset`: Training dataset.
+        :obj:`FederatedDataset`: Testing dataset.
+    """
+    train_data, test_data, process_x, process_y, transform_train, transform_test = load_data(root,
+                                                                                             dataset_name,
+                                                                                             num_of_clients,
+                                                                                             split_type,
+                                                                                             min_size,
+                                                                                             class_per_client,
+                                                                                             data_amount,
+                                                                                             iid_fraction,
+                                                                                             user,
+                                                                                             train_test_split,
+                                                                                             quantity_weights,
+                                                                                             alpha)
+
+    # The test sets of CIFAR datasets are kept centralized, so they are not marked as simulated.
+    test_simulated = True
+    if dataset_name == CIFAR10 or dataset_name == CIFAR100:
+        test_simulated = False
+
+    train_data = FederatedTensorDataset(train_data,
+                                        simulated=True,
+                                        do_simulate=False,
+                                        process_x=process_x,
+                                        process_y=process_y,
+                                        transform=transform_train)
+    test_data = FederatedTensorDataset(test_data,
+                                       simulated=test_simulated,
+                                       do_simulate=False,
+                                       process_x=process_x,
+                                       process_y=process_y,
+                                       transform=transform_test)
+    return train_data, test_data
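
A hypothetical call to construct_datasets mirroring the defaults in easyfl/config.yaml; the values are illustrative, and the first run downloads and simulates the dataset.

from easyfl.datasets import construct_datasets

train_data, test_data = construct_datasets(
    root="./data/",
    dataset_name="femnist",
    num_of_clients=100,
    split_type="iid",
    min_size=10,
    class_per_client=1,
    data_amount=0.05,
    iid_fraction=0.1,
    user=False,
    train_test_split=0.9,
    quantity_weights=None,
    alpha=0.5,
)
print("{} clients, {} training samples".format(len(train_data.users), train_data.total_size()))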

+ 0 - 0
easyfl/datasets/data_process/__init__.py


+ 55 - 0
easyfl/datasets/data_process/cifar10.py

@@ -0,0 +1,55 @@
+import numpy as np
+import torch
+import torchvision
+from torchvision import transforms
+
+
+class Cutout(object):
+    """Cutout data augmentation is adopted from https://github.com/uoguelph-mlrg/Cutout"""
+
+    def __init__(self, length=16):
+        self.length = length
+
+    def __call__(self, img):
+        """
+        Args:
+            img (Tensor): Tensor image of size (C, H, W).
+
+        Returns:
+            Tensor: Image with a square region of size length x length cut out of it.
+        """
+        h = img.size(1)
+        w = img.size(2)
+
+        mask = np.ones((h, w), np.float32)
+
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1: y2, x1: x2] = 0.
+
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
+
+transform_train = transforms.Compose([
+    torchvision.transforms.ToPILImage(mode='RGB'),
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+])
+
+transform_train.transforms.append(Cutout())
+
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+])
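
A small sanity-check sketch of the transforms above: transform_train expects a raw HWC uint8 array (the format stored by torchvision.datasets.CIFAR10.data) and returns a normalized tensor with a Cutout mask applied, while transform_test only normalizes.

import numpy as np

from easyfl.datasets.data_process.cifar10 import transform_test, transform_train

raw_image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # stand-in for one CIFAR-10 image
augmented = transform_train(raw_image)  # torch.Size([3, 32, 32]), randomly cropped/flipped/cut out
plain = transform_test(raw_image)       # normalization only, no augmentation
print(augmented.shape, plain.shape)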

+ 55 - 0
easyfl/datasets/data_process/cifar100.py

@@ -0,0 +1,55 @@
+import numpy as np
+import torch
+import torchvision
+from torchvision import transforms
+
+
+class Cutout(object):
+    """Cutout data augmentation is adopted from https://github.com/uoguelph-mlrg/Cutout"""
+
+    def __init__(self, length=16):
+        self.length = length
+
+    def __call__(self, img):
+        """
+        Args:
+            img (Tensor): Tensor image of size (C, H, W).
+
+        Returns:
+            Tensor: Image with a square region of size length x length cut out of it.
+        """
+        h = img.size(1)
+        w = img.size(2)
+
+        mask = np.ones((h, w), np.float32)
+
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1: y2, x1: x2] = 0.
+
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
+
+transform_train = transforms.Compose([
+    torchvision.transforms.ToPILImage(mode='RGB'),
+    transforms.RandomCrop(32, padding=4),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+])
+
+transform_train.transforms.append(Cutout())
+
+transform_test = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+])

+ 10 - 0
easyfl/datasets/data_process/femnist.py

@@ -0,0 +1,10 @@
+import torch
+
+
+def process_x(raw_x_batch):
+    raw_x_batch = torch.FloatTensor(raw_x_batch)
+    return raw_x_batch.view(-1, 1, 28, 28)
+
+
+def process_y(raw_y_batch):
+    return torch.LongTensor(raw_y_batch)

+ 142 - 0
easyfl/datasets/data_process/language_utils.py

@@ -0,0 +1,142 @@
+"""
+This code is adapted from LEAF.
+"""
+
+import json
+import re
+
+import numpy as np
+
+# ------------------------
+# utils for shakespeare dataset
+
+ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz}"
+NUM_LETTERS = len(ALL_LETTERS)
+
+
+def _one_hot(index, size):
+    """returns one-hot vector with given size and value 1 at given index"""
+    vec = [0 for _ in range(size)]
+    vec[int(index)] = 1
+    return vec
+
+
+def letter_to_vec(letter):
+    """returns one-hot representation of given letter"""
+    index = ALL_LETTERS.find(letter)
+    return _one_hot(index, NUM_LETTERS)
+
+
+def word_to_indices(word):
+    """returns a list of character indices
+
+    Args:
+        word: string
+    
+    Return:
+        indices: int list with length len(word)
+    """
+    indices = []
+    for c in word:
+        indices.append(ALL_LETTERS.find(c))
+    return indices
+
+
+# ------------------------
+# utils for sent140 dataset
+
+
+def split_line(line):
+    """split given line/phrase into list of words
+
+    Args:
+        line: string representing phrase to be split
+    
+    Return:
+        list of strings, with each string representing a word
+    """
+    return re.findall(r"[\w']+|[.,!?;]", line)
+
+
+def _word_to_index(word, indd):
+    """returns index of given word based on given lookup dictionary
+
+    returns the length of the lookup dictionary if word not found
+
+    Args:
+        word: string
+        indd: dictionary with string words as keys and int indices as values
+    """
+    if word in indd:
+        return indd[word]
+    else:
+        return len(indd)
+
+
+def line_to_indices(line, word2id, max_words=25):
+    """converts given phrase into list of word indices
+    
+    if the phrase has more than max_words words, returns a list containing
+    indices of the first max_words words
+    if the phrase has less than max_words words, repeatedly appends integer 
+    representing unknown index to returned list until the list's length is 
+    max_words
+
+    Args:
+        line: string representing phrase/sequence of words
+        word2id: dictionary with string words as keys and int indices as values
+        max_words: maximum number of word indices in returned list
+
+    Return:
+        indl: list of word indices, one index for each word in phrase
+    """
+    unk_id = len(word2id)
+    line_list = split_line(line)  # split phrase in words
+    indl = [word2id[w] if w in word2id else unk_id for w in line_list[:max_words]]
+    indl += [unk_id] * (max_words - len(indl))
+    return indl
+
+
+def bag_of_words(line, vocab):
+    """returns bag of words representation of given phrase using given vocab
+
+    Args:
+        line: string representing phrase to be parsed
+        vocab: dictionary with words as keys and indices as values
+
+    Return:
+        integer list
+    """
+    bag = [0] * len(vocab)
+    words = split_line(line)
+    for w in words:
+        if w in vocab:
+            bag[vocab[w]] += 1
+    return bag
+
+
+def get_word_emb_arr(path):
+    with open(path, 'r') as inf:
+        embs = json.load(inf)
+    vocab = embs['vocab']
+    word_emb_arr = np.array(embs['emba'])
+    indd = {}
+    for i in range(len(vocab)):
+        indd[vocab[i]] = i
+    vocab = {w: i for i, w in enumerate(embs['vocab'])}
+    return word_emb_arr, indd, vocab
+
+
+def val_to_vec(size, val):
+    """Converts target into one-hot.
+
+    Args:
+        size: Size of vector.
+        val: Integer in range [0, size].
+    Returns:
+         vec: one-hot vector with a 1 in the val element.
+    """
+    assert 0 <= val < size
+    vec = [0 for _ in range(size)]
+    vec[int(val)] = 1
+    return vec
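
A small usage sketch of these helpers (not part of this commit); the toy vocabulary below is made up for illustration:

    from easyfl.datasets.data_process.language_utils import (
        bag_of_words, letter_to_vec, line_to_indices, word_to_indices)

    # Shakespeare next-character data: character indices and a one-hot target.
    print(word_to_indices("to be"))  # one index per character, including the space
    print(len(letter_to_vec("t")))   # NUM_LETTERS (80)

    # Sent140 data: word indices and bag-of-words with a toy vocabulary.
    vocab = {"hello": 0, "world": 1}
    print(line_to_indices("hello world !", vocab, max_words=5))  # [0, 1, 2, 2, 2]; 2 is the unknown index
    print(bag_of_words("hello hello world", vocab))              # [2, 1]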

+ 15 - 0
easyfl/datasets/data_process/shakespeare.py

@@ -0,0 +1,15 @@
+import numpy as np
+import torch
+
+from easyfl.datasets.data_process.language_utils import word_to_indices, letter_to_vec
+
+
+def process_x(raw_x_batch):
+    x_batch = [word_to_indices(word) for word in raw_x_batch]
+    x_batch = np.array(x_batch)
+    return torch.LongTensor(x_batch)
+
+
+def process_y(raw_y_batch):
+    y_batch = [np.argmax(letter_to_vec(c)) for c in raw_y_batch]
+    return torch.LongTensor(y_batch)
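
A usage sketch (not part of this commit); real Shakespeare samples are 80-character sequences, but any string works:

    from easyfl.datasets.data_process.shakespeare import process_x, process_y

    raw_x = ["to be or not to be"]  # real inputs are 80-character sequences
    raw_y = ["t"]                   # the next character to predict

    x = process_x(raw_x)  # LongTensor of shape (1, 18): one character index per position
    y = process_y(raw_y)  # LongTensor of shape (1,): index of the target character
    print(x.shape, y)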

+ 427 - 0
easyfl/datasets/dataset.py

@@ -0,0 +1,427 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch
+from torch.utils.data import TensorDataset, DataLoader
+from torchvision.datasets.folder import default_loader, make_dataset
+
+from easyfl.datasets.dataset_util import TransformDataset, ImageDataset
+from easyfl.datasets.simulation import data_simulation, SIMULATE_IID
+
+logger = logging.getLogger(__name__)
+
+TEST_IN_SERVER = "test_in_server"
+TEST_IN_CLIENT = "test_in_client"
+
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+DEFAULT_MERGED_ID = "Merged"
+
+
+def default_process_x(raw_x_batch):
+    return torch.tensor(raw_x_batch)
+
+
+def default_process_y(raw_y_batch):
+    return torch.tensor(raw_y_batch)
+
+
+class FederatedDataset(ABC):
+    """The abstract class of federated dataset for EasyFL."""
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def loader(self, batch_size, shuffle=True):
+        """Get data loader.
+
+        Args:
+            batch_size (int): The batch size of the data loader.
+            shuffle (bool): Whether shuffle the data in the loader.
+        """
+        raise NotImplementedError("Data loader not implemented")
+
+    @abstractmethod
+    def size(self, cid):
+        """Get dataset size.
+
+        Args:
+            cid (str): client id.
+        """
+        raise NotImplementedError("Size not implemented")
+
+    @property
+    def users(self):
+        """Get client ids of the federated dataset."""
+        raise NotImplementedError("Users not implemented")
+
+
+class FederatedTensorDataset(FederatedDataset):
+    """Federated tensor dataset, where the data of clients is stored as tensors or lists.
+
+    Args:
+        data (dict): A dictionary of data, e.g., {"id1": {"x": [[], [], ...], "y": [...]}}.
+            If simulation has not been done previously, it is in the format {'x': [[], [], ...], 'y': [...]}.
+        transform (torchvision.transforms.transforms.Compose, optional): Transformation for data.
+        target_transform (torchvision.transforms.transforms.Compose, optional): Transformation for data labels.
+        process_x (function, optional): A function to preprocess the input data.
+        process_y (function, optional): A function to preprocess the data labels.
+        simulated (bool, optional): Whether the dataset is already simulated to federated learning settings.
+        do_simulate (bool, optional): Whether to conduct simulation. Only effective if the data is not already simulated.
+        num_of_clients (int, optional): The number of clients for simulation. Only needed when doing simulation.
+        simulation_method (str, optional): The split method. Only needed when doing simulation.
+        weights (list[float], optional): The targeted distribution of quantities to simulate quantity heterogeneity.
+            The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+            The `num_of_clients` should be divisible by `len(weights)`.
+            None means clients are simulated with the same data quantity.
+        alpha (float, optional): The parameter for Dirichlet distribution simulation, only for dir simulation.
+        min_size (int, optional): The minimal number of samples in each client, only for dir simulation.
+        class_per_client (int, optional): The number of classes in each client, only for non-iid by class simulation.
+    """
+
+    def __init__(self,
+                 data,
+                 transform=None,
+                 target_transform=None,
+                 process_x=default_process_x,
+                 process_y=default_process_y,
+                 simulated=False,
+                 do_simulate=True,
+                 num_of_clients=10,
+                 simulation_method=SIMULATE_IID,
+                 weights=None,
+                 alpha=0.5,
+                 min_size=10,
+                 class_per_client=1):
+        super(FederatedTensorDataset, self).__init__()
+        self.simulated = simulated
+        self.data = data
+        self._validate_data(self.data)
+        self.process_x = process_x
+        self.process_y = process_y
+        self.transform = transform
+        self.target_transform = target_transform
+        if simulated:
+            self._users = sorted(list(self.data.keys()))
+
+        elif do_simulate:
+            # For the provided simulation methods, we only support testing in the server for now
+            # TODO: support simulation for test data => test in clients
+            self.simulation(num_of_clients, simulation_method, weights, alpha, min_size, class_per_client)
+
+    def simulation(self, num_of_clients, niid=SIMULATE_IID, weights=None, alpha=0.5, min_size=10, class_per_client=1):
+        if self.simulated:
+            logger.warning("The dataset is already simulated; the simulation will not proceed.")
+            return
+        self._users, self.data = data_simulation(
+            self.data['x'],
+            self.data['y'],
+            num_of_clients,
+            niid,
+            weights,
+            alpha,
+            min_size,
+            class_per_client)
+        self.simulated = True
+
+    def loader(self, batch_size, client_id=None, shuffle=True, seed=0, transform=None, drop_last=False):
+        """Get dataset loader.
+
+        Args:
+            batch_size (int): The batch size.
+            client_id (str, optional): The id of client.
+            shuffle (bool, optional): Whether to shuffle before batching.
+            seed (int, optional): The shuffle seed.
+            transform (torchvision.transforms.transforms.Compose, optional): Data transformation.
+            drop_last (bool, optional): Whether to drop the last batch if its size is smaller than batch size.
+
+        Returns:
+            torch.utils.data.DataLoader: The data loader to load data.
+        """
+        # Simulation needs to be done before creating a data loader
+        if client_id is None:
+            data = self.data
+        else:
+            data = self.data[client_id]
+
+        data_x = data['x']
+        data_y = data['y']
+
+        data_x = np.array(data_x)
+        data_y = np.array(data_y)
+
+        data_x = self._input_process(data_x)
+        data_y = self._label_process(data_y)
+        if shuffle:
+            np.random.seed(seed)
+            rng_state = np.random.get_state()
+            np.random.shuffle(data_x)
+            np.random.set_state(rng_state)
+            np.random.shuffle(data_y)
+
+        transform = self.transform if transform is None else transform
+        if transform is not None:
+            dataset = TransformDataset(data_x,
+                                       data_y,
+                                       transform_x=transform,
+                                       transform_y=self.target_transform)
+        else:
+            dataset = TensorDataset(data_x, data_y)
+        loader = DataLoader(dataset=dataset,
+                            batch_size=batch_size,
+                            shuffle=shuffle,
+                            drop_last=drop_last)
+        return loader
+
+    @property
+    def users(self):
+        return self._users
+
+    @users.setter
+    def users(self, value):
+        self._users = value
+
+    def size(self, cid=None):
+        if cid is not None:
+            return len(self.data[cid]['y'])
+        else:
+            return len(self.data['y'])
+
+    def total_size(self):
+        if 'y' in self.data:
+            return len(self.data['y'])
+        else:
+            return sum([len(self.data[i]['y']) for i in self.data])
+
+    def _input_process(self, sample):
+        if self.process_x is not None:
+            sample = self.process_x(sample)
+        return sample
+
+    def _label_process(self, label):
+        if self.process_y is not None:
+            label = self.process_y(label)
+        return label
+
+    def _validate_data(self, data):
+        if self.simulated:
+            for i in data:
+                assert len(data[i]['x']) == len(data[i]['y'])
+        else:
+            assert len(data['x']) == len(data['y'])
+
+
+class FederatedImageDataset(FederatedDataset):
+    """
+    Federated image dataset, where the data of clients is stored in image folders.
+
+    Args:
+        root (str|list[str]): The root directory or directories of image data folder.
+            If the dataset is simulated to multiple clients, the root is a list of directories.
+            Otherwise, it is the directory of an image data folder.
+        simulated (bool): Whether the dataset is simulated to federated learning settings.
+        do_simulate (bool, optional): Whether to conduct simulation. Only effective if the data is not already simulated.
+        extensions (list[str], optional): A list of allowed image extensions.
+            Only one of `extensions` and `is_valid_file` can be specified.
+        is_valid_file (function, optional): A function that takes the path of an image file and checks whether it is valid.
+            Only one of `extensions` and `is_valid_file` can be specified.
+        transform (torchvision.transforms.transforms.Compose, optional): Transformation for data.
+        target_transform (torchvision.transforms.transforms.Compose, optional): Transformation for data labels.
+        num_of_clients (int, optional): The number of clients for simulation. Only needed when doing simulation.
+        simulation_method (str, optional): The split method. Only needed when doing simulation.
+        weights (list[float], optional): The targeted distribution of quantities to simulate quantity heterogeneity.
+            The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+            The `num_of_clients` should be divisible by `len(weights)`.
+            None means clients are simulated with the same data quantity.
+        alpha (float, optional): The parameter for Dirichlet distribution simulation, only for dir simulation.
+        min_size (int, optional): The minimal number of samples in each client, only for dir simulation.
+        class_per_client (int, optional): The number of classes in each client, only for non-iid by class simulation.
+        client_ids (list[str], optional): A list of client ids.
+            Each client id matches with an element in roots.
+            The client ids default to ["f0000000", "f0000001", ...] if not specified.
+    """
+
+    def __init__(self,
+                 root,
+                 simulated,
+                 do_simulate=True,
+                 extensions=IMG_EXTENSIONS,
+                 is_valid_file=None,
+                 transform=None,
+                 target_transform=None,
+                 client_ids="default",
+                 num_of_clients=10,
+                 simulation_method=SIMULATE_IID,
+                 weights=None,
+                 alpha=0.5,
+                 min_size=10,
+                 class_per_client=1):
+        super(FederatedImageDataset, self).__init__()
+        self.simulated = simulated
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if self.simulated:
+            self.data = {}
+            self.classes = {}
+            self.class_to_idx = {}
+            self.roots = root
+            self.num_of_clients = len(self.roots)
+            if client_ids == "default":
+                self.users = ["f%07.0f" % (i) for i in range(len(self.roots))]
+            else:
+                self.users = client_ids
+            for i in range(self.num_of_clients):
+                current_client_id = self.users[i]
+                classes, class_to_idx = self._find_classes(self.roots[i])
+                samples = make_dataset(self.roots[i], class_to_idx, extensions, is_valid_file)
+                if len(samples) == 0:
+                    msg = "Found 0 files in subfolders of: {}\n".format(self.roots[i])
+                    if extensions is not None:
+                        msg += "Supported extensions are: {}".format(",".join(extensions))
+                    raise RuntimeError(msg)
+
+                self.classes[current_client_id] = classes
+                self.class_to_idx[current_client_id] = class_to_idx
+                temp_client = {'x': [i[0] for i in samples], 'y': [i[1] for i in samples]}
+                self.data[current_client_id] = temp_client
+        elif do_simulate:
+            self.root = root
+            classes, class_to_idx = self._find_classes(self.root)
+            samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
+            if len(samples) == 0:
+                msg = "Found 0 files in subfolders of: {}\n".format(self.root)
+                if extensions is not None:
+                    msg += "Supported extensions are: {}".format(",".join(extensions))
+                raise RuntimeError(msg)
+            self.extensions = extensions
+            self.classes = classes
+            self.class_to_idx = class_to_idx
+            self.samples = samples
+            self.inputs = [i[0] for i in self.samples]
+            self.labels = [i[1] for i in self.samples]
+            self.simulation(num_of_clients, simulation_method, weights, alpha, min_size, class_per_client)
+
+    def simulation(self, num_of_clients, niid="iid", weights=[1], alpha=0.5, min_size=10, class_per_client=1):
+        if self.simulated:
+            logger.warning("The dataset is already simulated; the simulation will not proceed.")
+            return
+        self.users, self.data = data_simulation(self.inputs,
+                                                self.labels,
+                                                num_of_clients,
+                                                niid,
+                                                weights,
+                                                alpha,
+                                                min_size,
+                                                class_per_client)
+        self.simulated = True
+
+    def loader(self, batch_size, client_id=None, shuffle=True, seed=0, num_workers=2, transform=None):
+        """Get dataset loader.
+
+        Args:
+            batch_size (int): The batch size.
+            client_id (str, optional): The id of client.
+            shuffle (bool, optional): Whether to shuffle before batching.
+            seed (int, optional): The shuffle seed.
+            transform (torchvision.transforms.transforms.Compose, optional): Data transformation.
+            num_workers (int, optional): The number of workers for dataset loader.
+
+        Returns:
+            torch.utils.data.DataLoader: The data loader to load data.
+        """
+        assert self.simulated is True
+        if client_id is None:
+            data = self.data
+        else:
+            data = self.data[client_id]
+        data_x = data['x'][:]
+        data_y = data['y'][:]
+
+        # randomly shuffle data
+        if shuffle:
+            np.random.seed(seed)
+            rng_state = np.random.get_state()
+            np.random.shuffle(data_x)
+            np.random.set_state(rng_state)
+            np.random.shuffle(data_y)
+
+        transform = self.transform if transform is None else transform
+        dataset = ImageDataset(data_x, data_y, transform, self.target_transform)
+        loader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=batch_size,
+                                             shuffle=shuffle,
+                                             num_workers=num_workers,
+                                             pin_memory=False)
+        return loader
+
+    @property
+    def users(self):
+        return self._users
+
+    @users.setter
+    def users(self, value):
+        self._users = value
+
+    def size(self, cid=None):
+        if cid is not None:
+            return len(self.data[cid]['y'])
+        else:
+            return len(self.data['y'])
+
+    def _find_classes(self, dir):
+        """Get the classes of the dataset.
+
+        Args:
+            dir (str): Root directory path.
+
+        Returns:
+            tuple: (classes, class_to_idx) where classes are relative to directory and class_to_idx is a dictionary.
+
+        Note:
+            Need to ensure that no class is a subdirectory of another.
+        """
+        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
+        classes.sort()
+        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+        return classes, class_to_idx
+
+
+class FederatedTorchDataset(FederatedDataset):
+    """Wrapper over PyTorch dataset.
+
+    Args:
+        data (dict): A dictionary of client datasets, e.g., {"client_id1": dataset1, "client_id2": dataset2}.
+        users (list[str]): A list of client ids.
+    """
+
+    def __init__(self, data, users):
+        super(FederatedTorchDataset, self).__init__()
+        self.data = data
+        self._users = users
+
+    def loader(self, batch_size, client_id=None, shuffle=True, seed=0, num_workers=2, transform=None):
+        if client_id is None:
+            data = self.data
+        else:
+            data = self.data[client_id]
+
+        loader = torch.utils.data.DataLoader(
+            data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)
+        return loader
+
+    @property
+    def users(self):
+        return self._users
+
+    @users.setter
+    def users(self, value):
+        self._users = value
+
+    def size(self, cid=None):
+        if cid is not None:
+            return len(self.data[cid])
+        else:
+            return len(self.data)
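
A minimal end-to-end sketch of FederatedTensorDataset (not part of this commit); the toy data, client count, and batch size are made up:

    import numpy as np

    from easyfl.datasets.dataset import FederatedTensorDataset
    from easyfl.datasets.simulation import SIMULATE_IID

    # A toy centralized dataset: 100 samples with 8 features and 4 classes.
    data = {"x": np.random.rand(100, 8).astype(np.float32),
            "y": np.random.randint(0, 4, size=100)}

    # simulated=False and do_simulate=True trigger IID simulation into 5 clients
    # inside the constructor.
    dataset = FederatedTensorDataset(data,
                                     simulated=False,
                                     do_simulate=True,
                                     num_of_clients=5,
                                     simulation_method=SIMULATE_IID)

    cid = dataset.users[0]
    loader = dataset.loader(batch_size=16, client_id=cid, shuffle=True)
    print(cid, dataset.size(cid), dataset.total_size())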

+ 45 - 0
easyfl/datasets/dataset_util.py

@@ -0,0 +1,45 @@
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class ImageDataset(Dataset):
+    def __init__(self, images, labels, transform_x=None, transform_y=None):
+        self.images = images
+        self.labels = labels
+        self.transform_x = transform_x
+        self.transform_y = transform_y
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, index):
+        data, label = self.images[index], self.labels[index]
+        if self.transform_x is not None:
+            data = self.transform_x(Image.open(data))
+        else:
+            data = Image.open(data)
+        if self.transform_y is not None:
+            label = self.transform_y(label)
+        return data, label
+
+
+class TransformDataset(Dataset):
+    def __init__(self, images, labels, transform_x=None, transform_y=None):
+        self.data = images
+        self.targets = labels
+        self.transform_x = transform_x
+        self.transform_y = transform_y
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        sample = self.data[idx]
+        target = self.targets[idx]
+
+        if self.transform_x:
+            sample = self.transform_x(sample)
+        if self.transform_y:
+            target = self.transform_y(target)
+
+        return sample, target
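
A short sketch of TransformDataset with a torchvision transform (not part of this commit); ImageDataset works the same way but expects image file paths and opens them with PIL. The toy arrays are made up:

    import numpy as np
    from torchvision import transforms

    from easyfl.datasets.dataset_util import TransformDataset

    # Four fake HWC uint8 images with integer labels.
    images = [np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8) for _ in range(4)]
    labels = [0, 1, 2, 3]

    transform = transforms.Compose([transforms.ToTensor()])  # HWC uint8 -> CHW float in [0, 1]
    dataset = TransformDataset(images, labels, transform_x=transform)

    sample, target = dataset[0]
    print(sample.shape, target)  # torch.Size([3, 32, 32]) 0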

+ 1 - 0
easyfl/datasets/femnist/__init__.py

@@ -0,0 +1 @@
+from easyfl.datasets.femnist.femnist import Femnist

+ 109 - 0
easyfl/datasets/femnist/femnist.py

@@ -0,0 +1,109 @@
+import logging
+import os
+
+from easyfl.datasets.femnist.preprocess.data_to_json import data_to_json
+from easyfl.datasets.femnist.preprocess.get_file_dirs import get_file_dir
+from easyfl.datasets.femnist.preprocess.get_hashes import get_hash
+from easyfl.datasets.femnist.preprocess.group_by_writer import group_by_writer
+from easyfl.datasets.femnist.preprocess.match_hashes import match_hash
+from easyfl.datasets.utils.base_dataset import BaseDataset
+from easyfl.datasets.utils.download import download_url, extract_archive, download_from_google_drive
+
+logger = logging.getLogger(__name__)
+
+
+class Femnist(BaseDataset):
+    """FEMNIST dataset implementation. It prepares the FEMNIST dataset according to the configuration
+    and stores the processed datasets locally.
+
+    Attributes:
+        base_folder (str): The base folder path of the datasets folder.
+        class_url (str): The URL to download the `by_class` split of FEMNIST.
+        write_url (str): The URL to download the `by_write` split of FEMNIST.
+    """
+
+    def __init__(self,
+                 root,
+                 fraction,
+                 split_type,
+                 user,
+                 iid_user_fraction=0.1,
+                 train_test_split=0.9,
+                 minsample=10,
+                 num_class=62,
+                 num_of_client=100,
+                 class_per_client=2,
+                 setting_folder=None,
+                 seed=-1,
+                 **kwargs):
+        super(Femnist, self).__init__(root,
+                                      "femnist",
+                                      fraction,
+                                      split_type,
+                                      user,
+                                      iid_user_fraction,
+                                      train_test_split,
+                                      minsample,
+                                      num_class,
+                                      num_of_client,
+                                      class_per_client,
+                                      setting_folder,
+                                      seed)
+        self.class_url = "https://s3.amazonaws.com/nist-srd/SD19/by_class.zip"
+        self.write_url = "https://s3.amazonaws.com/nist-srd/SD19/by_write.zip"
+        self.packaged_data_files = {
+            "femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/oyhegd3c0pxa0tl/femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip",
+            "femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/jcg0xrz5qrri4tv/femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip"
+        }
+        # Google Drive ids
+        # self.packaged_data_files = {
+        #     "femnist_niid_100_10_1_0.05_0.1_sample_0.9.zip": "11vAxASl-af41iHpFqW2jixs1jOUZDXMS",
+        #     "femnist_iid_100_10_1_0.05_0.1_sample_0.9.zip": "1U9Sn2ACbidwhhihdJdZPfK2YddPMr33k"
+        # }
+
+    def download_packaged_dataset_and_extract(self, filename):
+        file_path = download_url(self.packaged_data_files[filename], self.base_folder)
+        extract_archive(file_path, remove_finished=True)
+
+    def download_raw_file_and_extract(self):
+        raw_data_folder = os.path.join(self.base_folder, "raw_data")
+        if not os.path.exists(raw_data_folder):
+            os.makedirs(raw_data_folder)
+        elif os.listdir(raw_data_folder):
+            logger.info("raw file exists")
+            return
+        class_path = download_url(self.class_url, raw_data_folder)
+        write_path = download_url(self.write_url, raw_data_folder)
+        extract_archive(class_path, remove_finished=True)
+        extract_archive(write_path, remove_finished=True)
+        logger.info("raw file is downloaded")
+
+    def preprocess(self):
+        intermediate_folder = os.path.join(self.base_folder, "intermediate")
+        if not os.path.exists(intermediate_folder):
+            os.makedirs(intermediate_folder)
+        if not os.path.exists(intermediate_folder + "/class_file_dirs.pkl"):
+            logger.info("extracting file directories of images")
+            get_file_dir(self.base_folder)
+            logger.info("finished extracting file directories of images")
+        if not os.path.exists(intermediate_folder + "/class_file_hashes.pkl"):
+            logger.info("calculating image hashes")
+            get_hash(self.base_folder)
+            logger.info("finished calculating image hashes")
+        if not os.path.exists(intermediate_folder + "/write_with_class.pkl"):
+            logger.info("assigning class labels to write images")
+            match_hash(self.base_folder)
+            logger.info("finished assigning class labels to write images")
+        if not os.path.exists(intermediate_folder + "/images_by_writer.pkl"):
+            logger.info("grouping images by writer")
+            group_by_writer(self.base_folder)
+            logger.info("finished grouping images by writer")
+
+    def convert_data_to_json(self):
+        all_data_folder = os.path.join(self.base_folder, "all_data")
+        if not os.path.exists(all_data_folder):
+            os.makedirs(all_data_folder)
+        if not os.listdir(all_data_folder):
+            logger.info("converting data to .json format")
+            data_to_json(self.base_folder)
+            logger.info("finished converting data to .json format")
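
A hedged sketch of the preprocessing pipeline these methods form (not part of this commit). The constructor arguments are illustrative, and attributes such as base_folder come from BaseDataset, which is outside this excerpt; normally the orchestration is driven by the dataset utilities rather than called by hand:

    from easyfl.datasets.femnist import Femnist

    femnist = Femnist(root="./data",  # illustrative values
                      fraction=0.05,
                      split_type="iid",
                      user=False,
                      num_of_client=100)

    # Raw NIST archives -> intermediate pickles -> JSON files under all_data/.
    femnist.download_raw_file_and_extract()
    femnist.preprocess()
    femnist.convert_data_to_json()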

+ 0 - 0
easyfl/datasets/femnist/preprocess/__init__.py


+ 94 - 0
easyfl/datasets/femnist/preprocess/data_to_json.py

@@ -0,0 +1,94 @@
+"""
+This code is adapted from LEAF with some modifications.
+
+It converts a list of (writer, [list of (file,class)]) tuples into a json object of the form:
+  {users: [bob, etc], num_samples: [124, etc.],
+  user_data: {bob : {x:[img1,img2,etc], y:[class1,class2,etc]}, etc}},
+where "img_" is a vectorized representation of the corresponding image.
+"""
+
+from __future__ import division
+
+import json
+import math
+import os
+
+import numpy as np
+from PIL import Image
+
+from easyfl.datasets.utils import util
+
+MAX_WRITERS = 100  # max number of writers per json file.
+
+
+def relabel_class(c):
+    """
+    maps hexadecimal class value (string) to a decimal number
+    returns:
+    - 0 through 9 for classes representing respective numbers
+    - 10 through 35 for classes representing respective uppercase letters
+    - 36 through 61 for classes representing respective lowercase letters
+    """
+    if c.isdigit() and int(c) < 40:
+        return int(c) - 30
+    elif int(c, 16) <= 90:  # uppercase
+        return int(c, 16) - 55
+    else:
+        return int(c, 16) - 61
+
+
+def data_to_json(base_folder):
+    by_writer_dir = os.path.join(base_folder, "intermediate", "images_by_writer")
+    writers = util.load_obj(by_writer_dir)
+
+    num_json = int(math.ceil(len(writers) / MAX_WRITERS))
+
+    users = []
+    num_samples = []
+    user_data = {}
+
+    writer_count = 0
+    json_index = 0
+    for (w, l) in writers:
+
+        users.append(w)
+        num_samples.append(len(l))
+        user_data[w] = {"x": [], "y": []}
+
+        size = 28, 28  # original image size is 128, 128
+        for (f, c) in l:
+            file_path = os.path.join(base_folder, f)
+            img = Image.open(file_path)
+            gray = img.convert("L")
+            gray.thumbnail(size, Image.ANTIALIAS)
+            arr = np.asarray(gray).copy()
+            vec = arr.flatten()
+            vec = vec / 255  # scale all pixel values to between 0 and 1
+            vec = vec.tolist()
+
+            nc = relabel_class(c)
+
+            user_data[w]["x"].append(vec)
+            user_data[w]["y"].append(nc)
+
+        writer_count += 1
+        if writer_count == MAX_WRITERS:
+            all_data = {}
+            all_data["users"] = users
+            all_data["num_samples"] = num_samples
+            all_data["user_data"] = user_data
+
+            file_name = "all_data_%d.json" % json_index
+            file_path = os.path.join(base_folder, "all_data", file_name)
+
+            print("writing %s" % file_name)
+
+            with open(file_path, "w") as outfile:
+                json.dump(all_data, outfile)
+
+            writer_count = 0
+            json_index += 1
+
+            users[:] = []
+            num_samples[:] = []
+            user_data.clear()
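
A few worked mappings of relabel_class for reference (not part of this commit); the inputs are the hexadecimal folder names used by the NIST by_class split:

    from easyfl.datasets.femnist.preprocess.data_to_json import relabel_class

    print(relabel_class("30"))  # 0,  the digit '0' (digits map to 0-9)
    print(relabel_class("41"))  # 10, uppercase 'A' (uppercase maps to 10-35)
    print(relabel_class("5a"))  # 35, uppercase 'Z'
    print(relabel_class("61"))  # 36, lowercase 'a' (lowercase maps to 36-61)
    print(relabel_class("7a"))  # 61, lowercase 'z'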

+ 71 - 0
easyfl/datasets/femnist/preprocess/get_file_dirs.py

@@ -0,0 +1,71 @@
+"""
+This code is adapted from LEAF with some modifications.
+
+Creates .pkl files for:
+1. list of directories of every image in 'by_class'
+2. list of directories of every image in 'by_write'
+the hierarchal structure of the data is as follows:
+- by_class -> classes -> folders containing images -> images
+- by_write -> folders containing writers -> writer -> types of images -> images
+the directories written into the files are of the form 'raw_data/...'
+"""
+
+import os
+
+from easyfl.datasets.utils import util
+
+
+def get_file_dir(base_folder):
+    class_files = []  # (class, file directory)
+    write_files = []  # (writer, file directory)
+
+    class_dir = os.path.join(base_folder, "raw_data", "by_class")
+    rel_class_dir = os.path.join(base_folder, "raw_data", "by_class")
+    classes = os.listdir(class_dir)
+    classes = [c for c in classes if len(c) == 2]
+
+    for cl in classes:
+        cldir = os.path.join(class_dir, cl)
+        rel_cldir = os.path.join(rel_class_dir, cl)
+        subcls = os.listdir(cldir)
+
+        subcls = [s for s in subcls if (("hsf" in s) and ("mit" not in s))]
+
+        for subcl in subcls:
+            subcldir = os.path.join(cldir, subcl)
+            rel_subcldir = os.path.join(rel_cldir, subcl)
+            images = os.listdir(subcldir)
+            image_dirs = [os.path.join(rel_subcldir, i) for i in images]
+
+            for image_dir in image_dirs:
+                class_files.append((cl, image_dir))
+
+    write_dir = os.path.join(base_folder, "raw_data", "by_write")
+    rel_write_dir = os.path.join(base_folder, "raw_data", "by_write")
+    write_parts = os.listdir(write_dir)
+
+    for write_part in write_parts:
+        writers_dir = os.path.join(write_dir, write_part)
+        rel_writers_dir = os.path.join(rel_write_dir, write_part)
+        writers = os.listdir(writers_dir)
+
+        for writer in writers:
+            writer_dir = os.path.join(writers_dir, writer)
+            rel_writer_dir = os.path.join(rel_writers_dir, writer)
+            wtypes = os.listdir(writer_dir)
+
+            for wtype in wtypes:
+                type_dir = os.path.join(writer_dir, wtype)
+                rel_type_dir = os.path.join(rel_writer_dir, wtype)
+                images = os.listdir(type_dir)
+                image_dirs = [os.path.join(rel_type_dir, i) for i in images]
+
+                for image_dir in image_dirs:
+                    write_files.append((writer, image_dir))
+
+    util.save_obj(
+        class_files,
+        os.path.join(base_folder, "intermediate", "class_file_dirs"))
+    util.save_obj(
+        write_files,
+        os.path.join(base_folder, "intermediate", "write_file_dirs"))

+ 55 - 0
easyfl/datasets/femnist/preprocess/get_hashes.py

@@ -0,0 +1,55 @@
+"""
+This code is adapted from LEAF with some modifications.
+"""
+
+import hashlib
+import logging
+import os
+
+from easyfl.datasets.utils import util
+
+logger = logging.getLogger(__name__)
+
+
+def get_hash(base_folder):
+    cfd = os.path.join(base_folder, "intermediate", "class_file_dirs")
+    wfd = os.path.join(base_folder, "intermediate", "write_file_dirs")
+    class_file_dirs = util.load_obj(cfd)
+    write_file_dirs = util.load_obj(wfd)
+
+    class_file_hashes = []
+    write_file_hashes = []
+
+    count = 0
+    for tup in class_file_dirs:
+        if (count % 100000 == 0):
+            logger.info("hashed %d class images" % count)
+
+        (cclass, cfile) = tup
+        file_path = os.path.join(base_folder, cfile)
+
+        chash = hashlib.md5(open(file_path, "rb").read()).hexdigest()
+
+        class_file_hashes.append((cclass, cfile, chash))
+
+        count += 1
+
+    cfhd = os.path.join(base_folder, "intermediate", "class_file_hashes")
+    util.save_obj(class_file_hashes, cfhd)
+
+    count = 0
+    for tup in write_file_dirs:
+        if (count % 100000 == 0):
+            logger.info("hashed %d write images" % count)
+
+        (cclass, cfile) = tup
+        file_path = os.path.join(base_folder, cfile)
+
+        chash = hashlib.md5(open(file_path, "rb").read()).hexdigest()
+
+        write_file_hashes.append((cclass, cfile, chash))
+
+        count += 1
+
+    wfhd = os.path.join(base_folder, "intermediate", "write_file_hashes")
+    util.save_obj(write_file_hashes, wfhd)

+ 25 - 0
easyfl/datasets/femnist/preprocess/group_by_writer.py

@@ -0,0 +1,25 @@
+"""
+This code is adapted from LEAF with some modifications.
+"""
+import os
+
+from easyfl.datasets.utils import util
+
+
+def group_by_writer(base_folder):
+    wwcd = os.path.join(base_folder, "intermediate", "write_with_class")
+    write_class = util.load_obj(wwcd)
+
+    writers = []  # each entry is a (writer, [list of (file, class)]) tuple
+    cimages = []
+    (cw, _, _) = write_class[0]
+    for (w, f, c) in write_class:
+        if w != cw:
+            writers.append((cw, cimages))
+            cw = w
+            cimages = [(f, c)]
+        cimages.append((f, c))
+    writers.append((cw, cimages))
+
+    ibwd = os.path.join(base_folder, "intermediate", "images_by_writer")
+    util.save_obj(writers, ibwd)

+ 25 - 0
easyfl/datasets/femnist/preprocess/match_hashes.py

@@ -0,0 +1,25 @@
+"""
+This code is adapted from LEAF with some modifications.
+"""
+import os
+
+from easyfl.datasets.utils import util
+
+
+def match_hash(base_folder):
+    cfhd = os.path.join(base_folder, "intermediate", "class_file_hashes")
+    wfhd = os.path.join(base_folder, "intermediate", "write_file_hashes")
+    class_file_hashes = util.load_obj(cfhd)
+    write_file_hashes = util.load_obj(wfhd)
+    class_hash_dict = {}
+    for i in range(len(class_file_hashes)):
+        (c, f, h) = class_file_hashes[len(class_file_hashes) - i - 1]
+        class_hash_dict[h] = (c, f)
+
+    write_classes = []
+    for tup in write_file_hashes:
+        (w, f, h) = tup
+        write_classes.append((w, f, class_hash_dict[h][0]))
+
+    wwcd = os.path.join(base_folder, "intermediate", "write_with_class")
+    util.save_obj(write_classes, wwcd)

+ 1 - 0
easyfl/datasets/shakespeare/__init__.py

@@ -0,0 +1 @@
+from easyfl.datasets.shakespeare.shakespeare import Shakespeare

+ 89 - 0
easyfl/datasets/shakespeare/shakespeare.py

@@ -0,0 +1,89 @@
+import logging
+import os
+
+from easyfl.datasets.shakespeare.utils.gen_all_data import generated_all_data
+from easyfl.datasets.shakespeare.utils.preprocess_shakespeare import shakespeare_preprocess
+from easyfl.datasets.utils.base_dataset import BaseDataset
+from easyfl.datasets.utils.download import download_url, extract_archive, download_from_google_drive
+
+logger = logging.getLogger(__name__)
+
+
+class Shakespeare(BaseDataset):
+    """Shakespeare dataset implementation. It prepares the Shakespeare dataset according to the configuration and stores the processed datasets locally.
+
+    Attributes:
+        base_folder (str): The base folder path of the datasets folder.
+        raw_data_url (str): The URL to download the raw Shakespeare text (the complete works from Project Gutenberg).
+    """
+
+    def __init__(self,
+                 root,
+                 fraction,
+                 split_type,
+                 user,
+                 iid_user_fraction=0.1,
+                 train_test_split=0.9,
+                 minsample=10,
+                 num_class=80,
+                 num_of_client=100,
+                 class_per_client=2,
+                 setting_folder=None,
+                 seed=-1,
+                 **kwargs):
+        super(Shakespeare, self).__init__(root,
+                                          "shakespeare",
+                                          fraction,
+                                          split_type,
+                                          user,
+                                          iid_user_fraction,
+                                          train_test_split,
+                                          minsample,
+                                          num_class,
+                                          num_of_client,
+                                          class_per_client,
+                                          setting_folder,
+                                          seed)
+        self.raw_data_url = "http://www.gutenberg.org/files/100/old/1994-01-100.zip"
+        self.packaged_data_files = {
+            "shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/5qr9ozziy3yfzss/shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip",
+            "shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip": "https://dl.dropboxusercontent.com/s/4p7osgjd2pecsi3/shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip"
+        }
+        # Google drive ids.
+        # self.packaged_data_files = {
+        #     "shakespeare_niid_100_10_1_0.05_0.1_sample_0.9.zip": "1zvmNiUNu7r0h4t0jBhOJ204qyc61NvfJ",
+        #     "shakespeare_iid_100_10_1_0.05_0.1_sample_0.9.zip": "1Lb8n1zDtrj2DX_QkjNnL6DH5IrnYFdsR"
+        # }
+
+    def download_packaged_dataset_and_extract(self, filename):
+        file_path = download_url(self.packaged_data_files[filename], self.base_folder)
+        extract_archive(file_path, remove_finished=True)
+
+    def download_raw_file_and_extract(self):
+        raw_data_folder = os.path.join(self.base_folder, "raw_data")
+        if not os.path.exists(raw_data_folder):
+            os.makedirs(raw_data_folder)
+        elif os.listdir(raw_data_folder):
+            logger.info("raw file exists")
+            return
+        raw_data_path = download_url(self.raw_data_url, raw_data_folder)
+        extract_archive(raw_data_path, remove_finished=True)
+        os.rename(os.path.join(raw_data_folder, "100.txt"), os.path.join(raw_data_folder, "raw_data.txt"))
+        logger.info("raw file is downloaded")
+
+    def preprocess(self):
+        filename = os.path.join(self.base_folder, "raw_data", "raw_data.txt")
+        raw_data_folder = os.path.join(self.base_folder, "raw_data")
+        if not os.path.exists(raw_data_folder):
+            os.makedirs(raw_data_folder)
+        shakespeare_preprocess(filename, raw_data_folder)
+
+    def convert_data_to_json(self):
+        all_data_folder = os.path.join(self.base_folder, "all_data")
+        if not os.path.exists(all_data_folder):
+            os.makedirs(all_data_folder)
+        if not os.listdir(all_data_folder):
+            logger.info("converting data to .json format")
+            generated_all_data(self.base_folder)
+            logger.info("finished converting data to .json format")

+ 0 - 0
easyfl/datasets/shakespeare/utils/__init__.py


+ 17 - 0
easyfl/datasets/shakespeare/utils/gen_all_data.py

@@ -0,0 +1,17 @@
+"""
+This code is adapted from LEAF with some modifications.
+"""
+
+import json
+import os
+
+from easyfl.datasets.shakespeare.utils.shake_utils import parse_data_in
+
+
+def generated_all_data(parent_path):
+    users_and_plays_path = os.path.join(parent_path, 'raw_data', 'users_and_plays.json')
+    txt_dir = os.path.join(parent_path, 'raw_data', 'by_play_and_character')
+    json_data = parse_data_in(txt_dir, users_and_plays_path)
+    json_path = os.path.join(parent_path, 'all_data', 'all_data.json')
+    with open(json_path, 'w') as outfile:
+        json.dump(json_data, outfile)

+ 183 - 0
easyfl/datasets/shakespeare/utils/preprocess_shakespeare.py

@@ -0,0 +1,183 @@
+"""Preprocesses the Shakespeare dataset for federated training.
+This code is adapted from LEAF with some modifications.
+"""
+
+import collections
+import json
+import os
+import re
+
+RANDOM_SEED = 1234
+# Regular expression to capture an actors name, and line continuation
+CHARACTER_RE = re.compile(r'^  ([a-zA-Z][a-zA-Z ]*)\. (.*)')
+CONT_RE = re.compile(r'^    (.*)')
+# The Comedy of Errors has errors in its indentation so we need to use
+# different regular expressions.
+COE_CHARACTER_RE = re.compile(r'^([a-zA-Z][a-zA-Z ]*)\. (.*)')
+COE_CONT_RE = re.compile(r'^(.*)')
+
+
+def _match_character_regex(line, comedy_of_errors=False):
+    return (COE_CHARACTER_RE.match(line) if comedy_of_errors
+            else CHARACTER_RE.match(line))
+
+
+def _match_continuation_regex(line, comedy_of_errors=False):
+    return (
+        COE_CONT_RE.match(line) if comedy_of_errors else CONT_RE.match(line))
+
+
+def _split_into_plays(shakespeare_full):
+    """Splits the full data by play."""
+    # List of tuples (play_name, dict from character to list of lines)
+    plays = []
+    discarded_lines = []  # Track discarded lines.
+    slines = shakespeare_full.splitlines(True)[1:]
+
+    # skip contents, the sonnets, and all's well that ends well
+    author_count = 0
+    start_i = 0
+    for i, l in enumerate(slines):
+        if 'by William Shakespeare' in l:
+            author_count += 1
+        if author_count == 2:
+            start_i = i - 5
+            break
+    slines = slines[start_i:]
+
+    current_character = None
+    comedy_of_errors = False
+    for i, line in enumerate(slines):
+        # This marks the end of the plays in the file.
+        if i > 124195 - start_i:
+            break
+        # This is a pretty good heuristic for detecting the start of a new play:
+        if 'by William Shakespeare' in line:
+            current_character = None
+            characters = collections.defaultdict(list)
+            # The title will be 2, 3, 4, 5, 6, or 7 lines above "by William Shakespeare".
+            if slines[i - 2].strip():
+                title = slines[i - 2]
+            elif slines[i - 3].strip():
+                title = slines[i - 3]
+            elif slines[i - 4].strip():
+                title = slines[i - 4]
+            elif slines[i - 5].strip():
+                title = slines[i - 5]
+            elif slines[i - 6].strip():
+                title = slines[i - 6]
+            else:
+                title = slines[i - 7]
+            title = title.strip()
+
+            assert title, ('Parsing error on line %d. Expecting title 2 to 7 lines above.' % i)
+            comedy_of_errors = (title == 'THE COMEDY OF ERRORS')
+            # Degenerate plays are removed at the end of the method.
+            plays.append((title, characters))
+            continue
+        match = _match_character_regex(line, comedy_of_errors)
+        if match:
+            character, snippet = match.group(1), match.group(2)
+            # Some character names are written with multiple casings, e.g., SIR_Toby
+            # and SIR_TOBY. To normalize the character names, we uppercase each name.
+            # Note that this was not done in the original preprocessing and is a
+            # recent fix.
+            character = character.upper()
+            if not (comedy_of_errors and character.startswith('ACT ')):
+                characters[character].append(snippet)
+                current_character = character
+                continue
+            else:
+                current_character = None
+                continue
+        elif current_character:
+            match = _match_continuation_regex(line, comedy_of_errors)
+            if match:
+                if comedy_of_errors and match.group(1).startswith('<'):
+                    current_character = None
+                    continue
+                else:
+                    characters[current_character].append(match.group(1))
+                    continue
+        # Didn't consume the line.
+        line = line.strip()
+        if line and i > 2646:
+            # Before 2646 are the sonnets, which we expect to discard.
+            discarded_lines.append('%d:%s' % (i, line))
+    # Remove degenerate "plays".
+    return [play for play in plays if len(play[1]) > 1], discarded_lines
+
+
+def _remove_nonalphanumerics(filename):
+    return re.sub('\\W+', '_', filename)
+
+
+def play_and_character(play, character):
+    return _remove_nonalphanumerics((play + '_' + character).replace(' ', '_'))
+
+
+def _get_train_test_by_character(plays, test_fraction=0.2):
+    """Splits character data into train and test sets.
+
+    `plays` is a list of (play, dict) tuples where play is a string and dict
+    maps character names to their lines.
+    If test_fraction <= 0, all_test_examples is returned empty.
+    """
+    skipped_characters = 0
+    all_train_examples = collections.defaultdict(list)
+    all_test_examples = collections.defaultdict(list)
+
+    def add_examples(example_dict, example_tuple_list):
+        for play, character, sound_bite in example_tuple_list:
+            example_dict[play_and_character(
+                play, character)].append(sound_bite)
+
+    users_and_plays = {}
+    for play, characters in plays:
+        curr_characters = list(characters.keys())
+        for c in curr_characters:
+            users_and_plays[play_and_character(play, c)] = play
+        for character, sound_bites in characters.items():
+            examples = [(play, character, sound_bite)
+                        for sound_bite in sound_bites]
+            if len(examples) <= 2:
+                skipped_characters += 1
+                # Skip characters with 2 or fewer lines since we need at least one
+                # train and one test line.
+                continue
+            train_examples = examples
+            if test_fraction > 0:
+                num_test = max(int(len(examples) * test_fraction), 1)
+                train_examples = examples[:-num_test]
+                test_examples = examples[-num_test:]
+                assert len(test_examples) == num_test
+                assert len(train_examples) >= len(test_examples)
+                add_examples(all_test_examples, test_examples)
+            add_examples(all_train_examples, train_examples)
+    return users_and_plays, all_train_examples, all_test_examples
+
+
+def _write_data_by_character(examples, output_directory):
+    """Writes a collection of data files by play & character."""
+    if not os.path.exists(output_directory):
+        os.makedirs(output_directory)
+    for character_name, sound_bites in examples.items():
+        filename = os.path.join(output_directory, character_name + '.txt')
+        with open(filename, 'w') as output:
+            for sound_bite in sound_bites:
+                output.write(sound_bite + '\n')
+
+
+def shakespeare_preprocess(input_filename, output_directory):
+    print('Splitting .txt data between users')
+    with open(input_filename, 'r') as input_file:
+        shakespeare_full = input_file.read()
+    plays, discarded_lines = _split_into_plays(shakespeare_full)
+    print('Discarded %d lines' % len(discarded_lines))
+    users_and_plays, all_examples, _ = _get_train_test_by_character(plays, test_fraction=-1.0)
+    with open(os.path.join(output_directory, 'users_and_plays.json'), 'w') as ouf:
+        json.dump(users_and_plays, ouf)
+    _write_data_by_character(all_examples,
+                             os.path.join(output_directory,
+                                          'by_play_and_character/'))

+ 69 - 0
easyfl/datasets/shakespeare/utils/shake_utils.py

@@ -0,0 +1,69 @@
+"""
+Helper functions for preprocessing shakespeare data.
+
+This code is adapted from LEAF with some modifications.
+"""
+import json
+import os
+import re
+
+
+def __txt_to_data(txt_dir, seq_length=80):
+    """Parses a text file into data for a next-character model.
+
+    Args:
+        txt_dir: path to text file
+        seq_length: length of strings in X
+    """
+    raw_text = ""
+    with open(txt_dir, 'r') as inf:
+        raw_text = inf.read()
+    raw_text = raw_text.replace('\n', ' ')
+    raw_text = re.sub(r"   *", r' ', raw_text)
+    dataX = []
+    dataY = []
+    for i in range(0, len(raw_text) - seq_length, 1):
+        seq_in = raw_text[i:i + seq_length]
+        seq_out = raw_text[i + seq_length]
+        dataX.append(seq_in)
+        dataY.append(seq_out)
+    return dataX, dataY
+
+
+def parse_data_in(data_dir, users_and_plays_path, raw=False):
+    """Parses the per-character text files in data_dir into a single data dictionary.
+
+    Returns a dictionary with keys: users, hierarchies, num_samples, user_data.
+    `raw` is a bool indicating whether to include the raw text of each user
+    under the 'raw' key of its user_data entry.
+    Users with no data are removed.
+    """
+    with open(users_and_plays_path, 'r') as inf:
+        users_and_plays = json.load(inf)
+    files = os.listdir(data_dir)
+    users = []
+    hierarchies = []
+    num_samples = []
+    user_data = {}
+    for f in files:
+        user = f[:-4]
+        passage = ''
+        filename = os.path.join(data_dir, f)
+        with open(filename, 'r') as inf:
+            passage = inf.read()
+        dataX, dataY = __txt_to_data(filename)
+        if (len(dataX) > 0):
+            users.append(user)
+            if raw:
+                user_data[user] = {'raw': passage}
+            else:
+                user_data[user] = {}
+            user_data[user]['x'] = dataX
+            user_data[user]['y'] = dataY
+            hierarchies.append(users_and_plays[user])
+            num_samples.append(len(dataY))
+    all_data = {}
+    all_data['users'] = users
+    all_data['hierarchies'] = hierarchies
+    all_data['num_samples'] = num_samples
+    all_data['user_data'] = user_data
+    return all_data
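
The sliding-window logic used by __txt_to_data, shown standalone (not part of this commit): a character stream of length L with window length seq_length yields L - seq_length (input, next character) pairs.

    text = "federated learning"  # length 18
    seq_length = 5

    pairs = [(text[i:i + seq_length], text[i + seq_length])
             for i in range(len(text) - seq_length)]

    print(len(pairs))  # 13 = 18 - 5
    print(pairs[0])    # ('feder', 'a')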

+ 350 - 0
easyfl/datasets/simulation.py

@@ -0,0 +1,350 @@
+import heapq
+import logging
+import math
+
+import numpy as np
+
+SIMULATE_IID = "iid"
+SIMULATE_NIID_DIR = "dir"
+SIMULATE_NIID_CLASS = "class"
+
+logger = logging.getLogger(__name__)
+
+
+def shuffle(data_x, data_y):
+    num_of_data = len(data_y)
+    data_x = np.array(data_x)
+    data_y = np.array(data_y)
+    index = [i for i in range(num_of_data)]
+    np.random.shuffle(index)
+    data_x = data_x[index]
+    data_y = data_y[index]
+    return data_x, data_y
+
+
+def equal_division(num_groups, data_x, data_y=None):
+    """Partition data into multiple clients with equal quantity.
+
+    Args:
+        num_groups (int): The number of groups to partition into.
+        data_x (list[Object]): A list of elements to be divided.
+        data_y (list[Object], optional): A list of data labels to be divided together with the data.
+
+    Returns:
+        list[list]: A list where each element is a list of data of a group/client.
+        list[list]: A list where each element is a list of data label of a group/client.
+
+    Example:
+        >>> equal_division(3, list(range(9)))
+        >>> ([[0,4,2],[3,1,7],[6,5,8]], [])
+    """
+    if data_y is not None:
+        assert (len(data_x) == len(data_y))
+        data_x, data_y = shuffle(data_x, data_y)
+    else:
+        np.random.shuffle(data_x)
+    num_of_data = len(data_x)
+    assert num_of_data > 0
+    data_per_client = num_of_data // num_groups
+    large_group_num = num_of_data - num_groups * data_per_client
+    small_group_num = num_groups - large_group_num
+    splitted_data_x = []
+    splitted_data_y = []
+    for i in range(small_group_num):
+        base_index = data_per_client * i
+        splitted_data_x.append(data_x[base_index: base_index + data_per_client])
+        if data_y is not None:
+            splitted_data_y.append(data_y[base_index: base_index + data_per_client])
+    small_size = data_per_client * small_group_num
+    data_per_client += 1
+    for i in range(large_group_num):
+        base_index = small_size + data_per_client * i
+        splitted_data_x.append(data_x[base_index: base_index + data_per_client])
+        if data_y is not None:
+            splitted_data_y.append(data_y[base_index: base_index + data_per_client])
+
+    return splitted_data_x, splitted_data_y
+
+
+def quantity_hetero(weights, data_x, data_y=None):
+    """Partition data into multiple clients with different quantities.
+    The number of groups is the same as the number of elements of `weights`.
+    The quantity of each group depends on the values of `weights`.
+
+    Args:
+        weights (list[float]): The targeted distribution of data quantities.
+            The values should sum up to 1. e.g., [0.1, 0.2, 0.7].
+        data_x (list[Object]): A list of elements to be divided.
+        data_y (list[Object], optional): A list of data labels to be divided together with the data.
+
+    Returns:
+        list[list]: A list where each element is a list of data of a group/client.
+        list[list]: A list where each element is a list of data label of a group/client.
+        
+    Example:
+        >>> quantity_hetero([0.1, 0.2, 0.7], list(range(0, 10)))
+        >>> ([[4], [8, 9], [6, 0, 1, 7, 3, 2, 5]], [])
+    """
+    # Round before comparing to tolerate floating-point error,
+    # e.g., sum([0.1, 0.2, 0.4, 0.2, 0.1]) is 1.0000000000000002, not exactly 1.
+    assert (round(sum(weights), 3) == 1)
+
+    if data_y is not None:
+        assert (len(data_x) == len(data_y))
+        data_x, data_y = shuffle(data_x, data_y)
+    else:
+        np.random.shuffle(data_x)
+    data_size = len(data_x)
+
+    i = 0
+
+    splitted_data_x = []
+    splitted_data_y = []
+    for w in weights:
+        size = math.floor(data_size * w)
+        splitted_data_x.append(data_x[i:i + size])
+        if data_y is not None:
+            splitted_data_y.append(data_y[i:i + size])
+        i += size
+
+    parts = len(weights)
+    if i < data_size:
+        remain = data_size - i
+        for i in range(-remain, 0, 1):
+            splitted_data_x[(-i) % parts].append(data_x[i])
+            if data_y is not None:
+                splitted_data_y[(-i) % parts].append(data_y[i])
+
+    return splitted_data_x, splitted_data_y
+
+
+def iid(data_x, data_y, num_of_clients, x_dtype, y_dtype):
+    """Randomly partition a dataset into multiple clients with equal data quantity (the difference is at most 1).
+
+    Args:
+        data_x (list[Object]): A list of data.
+        data_y (list[Object]): A list of dataset labels.
+        num_of_clients (int): The number of clients to partition to.
+        x_dtype (numpy.dtype): The type of data.
+        y_dtype (numpy.dtype): The type of data label.
+
+    Returns:
+        list[str]: A list of client ids.
+        dict: The partitioned data, key is client id, value is the client data. e.g., {'client_1': {'x': [data_x], 'y': [data_y]}}.
+    """
+    data_x, data_y = shuffle(data_x, data_y)
+    x_divided_list, y_divided_list = equal_division(num_of_clients, data_x, data_y)
+    clients = []
+    federated_data = {}
+    for i in range(num_of_clients):
+        client_id = "f%07.0f" % (i)
+        temp_client = {}
+        temp_client['x'] = np.array(x_divided_list[i]).astype(x_dtype)
+        temp_client['y'] = np.array(y_divided_list[i]).astype(y_dtype)
+        federated_data[client_id] = temp_client
+        clients.append(client_id)
+    return clients, federated_data
+
+
+def non_iid_dirichlet(data_x, data_y, num_of_clients, alpha, min_size, x_dtype, y_dtype):
+    """Partition dataset into multiple clients following the Dirichlet process.
+
+    Args:
+        data_x (list[Object]): A list of data.
+        data_y (list[Object]): A list of dataset labels.
+        num_of_clients (int): The number of clients to partition to.
+        alpha (float): The parameter for Dirichlet process simulation.
+        min_size (int): The minimum number of data size of a client.
+        x_dtype (numpy.dtype): The type of data.
+        y_dtype (numpy.dtype): The type of data label.
+
+    Returns:
+        list[str]: A list of client ids.
+        dict: The partitioned data, key is client id, value is the client data. e.g., {'client_1': {'x': [data_x], 'y': [data_y]}}.
+    """
+    n_train = data_x.shape[0]
+    current_min_size = 0
+    num_class = np.amax(data_y) + 1
+    data_size = data_y.shape[0]
+    net_dataidx_map = {}
+
+    while current_min_size < min_size:
+        idx_batch = [[] for _ in range(num_of_clients)]
+        for k in range(num_class):
+            idx_k = np.where(data_y == k)[0]
+            np.random.shuffle(idx_k)
+            proportions = np.random.dirichlet(np.repeat(alpha, num_of_clients))
+            # Using the proportions from the Dirichlet distribution, only assign data to clients that currently hold less than the average amount
+            proportions = np.array(
+                [p * (len(idx_j) < data_size / num_of_clients) for p, idx_j in zip(proportions, idx_batch)])
+            # scale proportions
+            proportions = proportions / proportions.sum()
+            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
+            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
+            current_min_size = min([len(idx_j) for idx_j in idx_batch])
+
+    federated_data = {}
+    clients = []
+    for j in range(num_of_clients):
+        np.random.shuffle(idx_batch[j])
+        client_id = "f%07.0f" % j
+        clients.append(client_id)
+        temp = {}
+        temp['x'] = np.array(data_x[idx_batch[j]]).astype(x_dtype)
+        temp['y'] = np.array(data_y[idx_batch[j]]).astype(y_dtype)
+        federated_data[client_id] = temp
+        net_dataidx_map[client_id] = idx_batch[j]
+    print_data_distribution(data_y, net_dataidx_map)
+    return clients, federated_data
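+
+# A hedged usage sketch (not part of the library): `x_arr` and `y_arr` below are
+# hypothetical numpy arrays of samples and integer labels. A small alpha (e.g., 0.1)
+# concentrates each client on a few classes, while a large alpha (e.g., 100) makes
+# the per-client class proportions close to uniform:
+#
+#   clients, fed_data = non_iid_dirichlet(x_arr, y_arr, num_of_clients=10, alpha=0.5,
+#                                         min_size=10, x_dtype=x_arr.dtype, y_dtype=y_arr.dtype)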
+
+
+def non_iid_class(data_x, data_y, class_per_client, num_of_clients, x_dtype, y_dtype, stack_x=True):
+    """Partition dataset into multiple clients based on label classes.
+    Each client contains [1, n] classes, where n is the number of classes of a dataset.
+
+    Note: Each class is divided into `ceil(class_per_client * num_of_clients / num_class)` parts,
+        and each client takes `class_per_client` parts, each from a different class, to construct its dataset.
+
+    Args:
+        data_x (list[Object]): A list of data.
+        data_y (list[Object]): A list of dataset labels.
+        class_per_client (int): The number of classes in each client.
+        num_of_clients (int): The number of clients to partition to.
+        x_dtype (numpy.dtype): The type of data.
+        y_dtype (numpy.dtype): The type of data label.
+        stack_x (bool, optional): A flag to indicate whether using np.vstack or append to construct dataset.
+
+    Returns:
+        list[str]: A list of client ids.
+        dict: The partitioned data, key is client id, value is the client data. e.g., {'client_1': {'x': [data_x], 'y': [data_y]}}.
+    """
+    num_class = np.amax(data_y) + 1
+    all_index = []
+    clients = []
+    data_index_map = {}
+    for i in range(num_class):
+        # get indexes for all data with current label i at index i in all_index
+        all_index.append(np.where(data_y == i)[0].tolist())
+
+    federated_data = {}
+
+    # total no. of parts
+    total_amount = class_per_client * num_of_clients
+    # no. of parts each class should be divided into
+    parts_per_class = math.ceil(total_amount / num_class)
+
+    for i in range(num_of_clients):
+        client_id = "f%07.0f" % (i)
+        clients.append(client_id)
+        data_index_map[client_id] = []
+        data = {}
+        data['x'] = np.array([])
+        data['y'] = np.array([])
+        federated_data[client_id] = data
+
+    class_map = {}
+    parts_consumed = []
+    for i in range(num_class):
+        class_map[i], _ = equal_division(parts_per_class, all_index[i])
+        heapq.heappush(parts_consumed, (0, i))
+    for i in clients:
+        for j in range(class_per_client):
+            class_chosen = heapq.heappop(parts_consumed)
+            part_indexes = class_map[class_chosen[1]].pop(0)
+            if len(federated_data[i]['x']) != 0:
+                if stack_x:
+                    federated_data[i]['x'] = np.vstack((federated_data[i]['x'], data_x[part_indexes])).astype(x_dtype)
+                else:
+                    federated_data[i]['x'] = np.append(federated_data[i]['x'], data_x[part_indexes]).astype(x_dtype)
+                federated_data[i]['y'] = np.append(federated_data[i]['y'], data_y[part_indexes]).astype(y_dtype)
+            else:
+                federated_data[i]['x'] = data_x[part_indexes].astype(x_dtype)
+                federated_data[i]['y'] = data_y[part_indexes].astype(y_dtype)
+            heapq.heappush(parts_consumed, (class_chosen[0] + 1, class_chosen[1]))
+            data_index_map[i].extend(part_indexes)
+    print_data_distribution(data_y, data_index_map)
+    return clients, federated_data
+
+
+def data_simulation(data_x, data_y, num_of_clients, data_distribution, weights=None, alpha=0.5, min_size=10,
+                    class_per_client=1, stack_x=True):
+    """Simulate federated learning datasets by partitioning a data into multiple clients using different strategies.
+
+    Args:
+        data_x (list[Object]): A list of data.
+        data_y (list[Object]): A list of dataset labels.
+        num_of_clients (int): The number of clients to partition to.
+        data_distribution (str): The ways to partition the dataset, options:
+            `iid`: Partition dataset into multiple clients with equal quantity (difference is less than 1) randomly.
+            `dir`: partition dataset into multiple clients following the Dirichlet process.
+            `class`: partition dataset into multiple clients based on classes.
+        weights (list[float], optional): The targeted distribution of data quantities,
+            used to simulate data quantity heterogeneity. The values should sum up to 1, e.g., [0.1, 0.2, 0.7].
+            Note: `num_of_clients` should be divisible by `len(weights)`.
+            When `weights=None`, the data quantity of clients only depends on `data_distribution`.
+        alpha (float, optional): The parameter for Dirichlet process simulation.
+            It is only applicable when data_distribution is `dir`.
+        min_size (int, optional): The minimum number of data size of a client.
+            It is only applicable when data_distribution is `dir`.
+        class_per_client (int): The number of classes in each client.
+            It is only applicable when data_distribution is `class`.
+        stack_x (bool, optional): A flag to indicate whether using np.vstack or append to construct dataset.
+            It is only applicable when data_distribution is `class`.
+
+    Raises:
+        ValueError: When the simulation method `data_distribution` is not supported.
+
+    Returns:
+        list[str]: A list of client ids.
+        dict: The partitioned data, key is client id, value is the client data. e.g., {'client_1': {'x': [data_x], 'y': [data_y]}}.
+    """
+    data_x = np.array(data_x)
+    data_y = np.array(data_y)
+    x_dtype = data_x.dtype
+    y_dtype = data_y.dtype
+    if weights is not None:
+        assert num_of_clients % len(weights) == 0
+        num_of_clients = num_of_clients // len(weights)
+
+    if data_distribution == SIMULATE_IID:
+        group_client_list, group_federated_data = iid(data_x, data_y, num_of_clients, x_dtype, y_dtype)
+    elif data_distribution == SIMULATE_NIID_DIR:
+        group_client_list, group_federated_data = non_iid_dirichlet(data_x, data_y, num_of_clients, alpha, min_size,
+                                                                    x_dtype, y_dtype)
+    elif data_distribution == SIMULATE_NIID_CLASS:
+        group_client_list, group_federated_data = non_iid_class(data_x, data_y, class_per_client, num_of_clients,
+                                                                x_dtype,
+                                                                y_dtype, stack_x=stack_x)
+    else:
+        raise ValueError("Simulation type not supported")
+    if weights is None:
+        return group_client_list, group_federated_data
+
+    clients = []
+    federated_data = {}
+    cur_key = 0
+    for i in group_client_list:
+        current_client = group_federated_data[i]
+        input_lists, label_lists = quantity_hetero(weights, current_client['x'], current_client['y'])
+        for j in range(len(input_lists)):
+            client_id = "f%07.0f" % (cur_key)
+            temp_client = {}
+            temp_client['x'] = np.array(input_lists[j]).astype(x_dtype)
+            temp_client['y'] = np.array(label_lists[j]).astype(y_dtype)
+            federated_data[client_id] = temp_client
+            clients.append(client_id)
+            cur_key += 1
+    return clients, federated_data
+
+
+def print_data_distribution(data_y, data_index_map):
+    """Log the distribution of client datasets."""
+    data_distribution = {}
+    for index, dataidx in data_index_map.items():
+        unique_values, counts = np.unique(data_y[dataidx], return_counts=True)
+        distribution = {unique_values[i]: counts[i] for i in range(len(unique_values))}
+        data_distribution[index] = distribution
+    logger.info(data_distribution)
+    return data_distribution
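+
+
+if __name__ == "__main__":
+    # A minimal, hedged usage sketch (not part of the library): partition a toy
+    # dataset of 100 samples with 10 classes into 4 IID clients. SIMULATE_IID is
+    # the module-level constant checked in data_simulation above.
+    toy_x = [[float(i)] for i in range(100)]
+    toy_y = [i % 10 for i in range(100)]
+    cids, fed = data_simulation(toy_x, toy_y, num_of_clients=4, data_distribution=SIMULATE_IID)
+    print({cid: len(fed[cid]["y"]) for cid in cids})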

+ 0 - 0
easyfl/datasets/utils/__init__.py


+ 158 - 0
easyfl/datasets/utils/base_dataset.py

@@ -0,0 +1,158 @@
+import logging
+import os
+from abc import abstractmethod
+
+from easyfl.datasets.utils.remove_users import remove
+from easyfl.datasets.utils.sample import sample, extreme
+from easyfl.datasets.utils.split_data import split_train_test
+
+logger = logging.getLogger(__name__)
+
+CIFAR10 = "cifar10"
+CIFAR100 = "cifar100"
+
+
+class BaseDataset(object):
+    """The internal base dataset implementation.
+
+    Args:
+        root (str): The root directory where datasets stored.
+        dataset_name (str): The name of the dataset.
+        fraction (float): The fraction of the data chosen from the raw data to use.
+        num_of_client (int): The targeted number of clients to construct.
+        split_type (str): The type of statistical simulation, options: iid, niid, dir, and class.
+            `iid` means independent and identically distributed data.
+            `niid` means non-independent and identically distributed data for Femnist and Shakespeare.
+            `dir` means using Dirichlet process to simulate non-iid data, for CIFAR-10 and CIFAR-100 datasets.
+            `class` means partitioning the dataset by label classes, for datasets like CIFAR-10, CIFAR-100.
+        minsample (int): The minimal number of samples in each client.
+            It is applicable for LEAF datasets and dir simulation of CIFAR-10 and CIFAR-100.
+        class_per_client (int): The number of classes in each client. Only applicable when the split_type is 'class'.
+        iid_user_fraction (float): The fraction of the number of clients used when the split_type is 'iid'.
+        user (bool): A flag to indicate whether partition users of the dataset into train-test groups.
+            Only applicable to LEAF datasets.
+            True means partitioning users of the dataset into train-test groups.
+            False means partitioning each user's samples into train-test groups.
+        train_test_split (float): The fraction of data for training; the rest are for testing.
+            e.g., 0.9 means 90% of data are used for training and 10% are used for testing.
+        num_class (int): The number of classes in this dataset.
+        seed (int, optional): Random seed.
+    """
+
+    def __init__(self,
+                 root,
+                 dataset_name,
+                 fraction,
+                 split_type,
+                 user,
+                 iid_user_fraction,
+                 train_test_split,
+                 minsample,
+                 num_class,
+                 num_of_client,
+                 class_per_client,
+                 setting_folder,
+                 seed=-1,
+                 **kwargs):
+        # file_path = os.path.dirname(os.path.realpath(__file__))
+        # self.base_folder = os.path.join(os.path.dirname(file_path), "data", dataset_name)
+        self.base_folder = root
+        self.dataset_name = dataset_name
+        self.fraction = fraction
+        self.split_type = split_type  # iid, niid, class
+        self.user = user
+        self.iid_user_fraction = iid_user_fraction
+        self.train_test_split = train_test_split
+        self.minsample = minsample
+        self.num_class = num_class
+        self.num_of_client = num_of_client
+        self.class_per_client = class_per_client
+        self.seed = seed
+        if split_type == "iid":
+            assert not self.user
+            self.iid = True
+        elif split_type == "niid":
+            # if niid, user can be either True or False
+            self.iid = False
+
+        self.setting_folder = setting_folder
+        self.data_folder = os.path.join(self.base_folder, self.setting_folder)
+
+    @abstractmethod
+    def download_packaged_dataset_and_extract(self, filename):
+        raise NotImplementedError("download_packaged_dataset_and_extract not implemented")
+
+    @abstractmethod
+    def download_raw_file_and_extract(self):
+        raise NotImplementedError("download_raw_file_and_extract not implemented")
+
+    @abstractmethod
+    def preprocess(self):
+        raise NotImplementedError("preprocess not implemented")
+
+    @abstractmethod
+    def convert_data_to_json(self):
+        raise NotImplementedError("convert_data_to_json not implemented")
+
+    @staticmethod
+    def get_setting_folder(dataset, split_type, num_of_client, min_size, class_per_client,
+                           fraction, iid_fraction, user_str, train_test_split, alpha=None, weights=None):
+        if dataset == CIFAR10 or dataset == CIFAR100:
+            return "{}_{}_{}_{}_{}_{}_{}".format(dataset, split_type, num_of_client, min_size, class_per_client, alpha,
+                                                 1 if weights else 0)
+        else:
+            return "{}_{}_{}_{}_{}_{}_{}_{}_{}".format(dataset, split_type, num_of_client, min_size, class_per_client,
+                                                       fraction, iid_fraction, user_str, train_test_split)
+
+    def setup(self):
+        self.download_raw_file_and_extract()
+        self.preprocess()
+        self.convert_data_to_json()
+
+    def sample_customized(self):
+        meta_folder = os.path.join(self.base_folder, "meta")
+        if not os.path.exists(meta_folder):
+            os.makedirs(meta_folder)
+        sample_folder = os.path.join(self.data_folder, "sampled_data")
+        if not os.path.exists(sample_folder):
+            os.makedirs(sample_folder)
+        if not os.listdir(sample_folder):
+            sample(self.base_folder, self.data_folder, meta_folder, self.fraction, self.iid, self.iid_user_fraction, self.seed)
+
+    def sample_extreme(self):
+        meta_folder = os.path.join(self.base_folder, "meta")
+        if not os.path.exists(meta_folder):
+            os.makedirs(meta_folder)
+        sample_folder = os.path.join(self.data_folder, "sampled_data")
+        if not os.path.exists(sample_folder):
+            os.makedirs(sample_folder)
+        if not os.listdir(sample_folder):
+            extreme(self.base_folder, self.data_folder, meta_folder, self.fraction, self.num_class, self.num_of_client, self.class_per_client, self.seed)
+
+    def remove_unqualified_user(self):
+        rm_folder = os.path.join(self.data_folder, "rem_user_data")
+        if not os.path.exists(rm_folder):
+            os.makedirs(rm_folder)
+        if not os.listdir(rm_folder):
+            remove(self.data_folder, self.dataset_name, self.minsample)
+
+    def split_train_test_set(self):
+        meta_folder = os.path.join(self.base_folder, "meta")
+        train = os.path.join(self.data_folder, "train")
+        if not os.path.exists(train):
+            os.makedirs(train)
+        test = os.path.join(self.data_folder, "test")
+        if not os.path.exists(test):
+            os.makedirs(test)
+        if not os.listdir(train) and not os.listdir(test):
+            split_train_test(self.data_folder, meta_folder, self.dataset_name, self.user, self.train_test_split, self.seed)
+
+    def sampling(self):
+        if self.split_type == "iid":
+            self.sample_customized()
+        elif self.split_type == "niid":
+            self.sample_customized()
+        elif self.split_type == "class":
+            self.sample_extreme()
+        self.remove_unqualified_user()
+        self.split_train_test_set()
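+
+
+# A hedged sketch (not part of the library) of how a concrete dataset is expected
+# to use this base class: subclass it, implement the four abstract hooks, then call
+# setup() followed by sampling(). `MyDataset` and its hook bodies are hypothetical.
+#
+#   class MyDataset(BaseDataset):
+#       def download_packaged_dataset_and_extract(self, filename):
+#           ...  # fetch a pre-packaged archive into self.base_folder
+#
+#       def download_raw_file_and_extract(self):
+#           ...  # fetch and unpack the raw files
+#
+#       def preprocess(self):
+#           ...  # reorganize the raw files
+#
+#       def convert_data_to_json(self):
+#           ...  # write per-user json files into <base_folder>/all_data
+#
+#   ds = MyDataset(root="./data/mydataset", dataset_name="mydataset", fraction=0.1,
+#                  split_type="niid", user=False, iid_user_fraction=0.01,
+#                  train_test_split=0.9, minsample=10, num_class=10,
+#                  num_of_client=100, class_per_client=2, setting_folder="demo")
+#   ds.setup()      # download + preprocess + convert to json
+#   ds.sampling()   # sample, drop small users, split train/test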

+ 2 - 0
easyfl/datasets/utils/constants.py

@@ -0,0 +1,2 @@
+DATASETS = ['sent140', 'femnist', 'shakespeare', 'celeba', 'synthetic']
+SEED_FILES = {'sampling': 'sampling_seed.txt', 'split': 'split_seed.txt'}

+ 176 - 0
easyfl/datasets/utils/download.py

@@ -0,0 +1,176 @@
+"""
+These codes are adopted from torchvision with some modifications.
+"""
+import gzip
+import hashlib
+import logging
+import os
+import tarfile
+import zipfile
+
+import requests
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+def gen_bar_updater():
+    pbar = tqdm(total=None)
+
+    def bar_update(count, block_size, total_size):
+        if pbar.total is None and total_size:
+            pbar.total = total_size
+        progress_bytes = count * block_size
+        pbar.update(progress_bytes - pbar.n)
+
+    return bar_update
+
+
+def calculate_md5(fpath, chunk_size=1024 * 1024):
+    md5 = hashlib.md5()
+    with open(fpath, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(fpath, md5, **kwargs):
+    return md5 == calculate_md5(fpath, **kwargs)
+
+
+def check_integrity(fpath, md5=None):
+    if not os.path.isfile(fpath):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(fpath, md5)
+
+
+def download_url(url, root, filename=None, md5=None):
+    """Download a file from a url and place it in root.
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL.
+        md5 (str, optional): MD5 checksum of the download. If None, the download is not verified.
+    """
+    import urllib.request
+    import urllib.error
+
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    os.makedirs(root, exist_ok=True)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        logger.info("Using downloaded and verified file: " + fpath)
+        return fpath
+    else:  # download the file
+        try:
+            logger.info("Downloading {} to {}".format(url, fpath))
+            urllib.request.urlretrieve(
+                url, fpath,
+                reporthook=gen_bar_updater()
+            )
+        except (urllib.error.URLError, IOError) as e:
+            if url[:5] != 'https':
+                raise e
+            url = url.replace('https:', 'http:')
+            logger.info("Failed download. Trying https -> http instead."
+                        "Downloading {} to {}".format(url, fpath))
+            urllib.request.urlretrieve(
+                url, fpath,
+                reporthook=gen_bar_updater()
+            )
+
+        # check integrity of downloaded file
+        if not check_integrity(fpath, md5):
+            raise RuntimeError("File not found or corrupted.")
+    return fpath
+
+
+def download_from_google_drive(id, destination):
+    # taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
+    URL = "https://docs.google.com/uc?export=download"
+
+    session = requests.Session()
+
+    response = session.get(URL, params={'id': id}, stream=True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = {'id': id, 'confirm': token}
+        response = session.get(URL, params=params, stream=True)
+    else:
+        raise FileNotFoundError("Google drive file id does not exist")
+    save_response_content(response, destination)
+
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+
+def save_response_content(response, destination):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        for chunk in response.iter_content(CHUNK_SIZE):
+            if chunk:  # filter out keep-alive new chunks
+                f.write(chunk)
+
+
+def _is_tarxz(filename):
+    return filename.endswith(".tar.xz")
+
+
+def _is_tar(filename):
+    return filename.endswith(".tar")
+
+
+def _is_targz(filename):
+    return filename.endswith(".tar.gz")
+
+
+def _is_tgz(filename):
+    return filename.endswith(".tgz")
+
+
+def _is_gzip(filename):
+    return filename.endswith(".gz") and not filename.endswith(".tar.gz")
+
+
+def _is_zip(filename):
+    return filename.endswith(".zip")
+
+
+def extract_archive(from_path, to_path=None, remove_finished=False):
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if _is_tar(from_path):
+        with tarfile.open(from_path, 'r') as tar:
+            tar.extractall(path=to_path)
+    elif _is_targz(from_path) or _is_tgz(from_path):
+        with tarfile.open(from_path, 'r:gz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_tarxz(from_path):
+        with tarfile.open(from_path, 'r:xz') as tar:
+            tar.extractall(path=to_path)
+    elif _is_gzip(from_path):
+        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
+        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
+            out_f.write(zip_f.read())
+    elif _is_zip(from_path):
+        with zipfile.ZipFile(from_path, 'r') as z:
+            z.extractall(to_path)
+    else:
+        raise ValueError("file format not supported")
+
+    if remove_finished:
+        os.remove(from_path)
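+
+
+# A hedged usage sketch (not part of the library). The URL and paths below are
+# placeholders, not real dataset locations:
+#
+#   archive = download_url("https://example.com/some_dataset.tar.gz", root="./data")
+#   extract_archive(archive, to_path="./data", remove_finished=False)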

+ 62 - 0
easyfl/datasets/utils/remove_users.py

@@ -0,0 +1,62 @@
+"""
+Removes users with less than the given number of samples.
+
+These codes are adopted from LEAF with some modifications.
+"""
+
+import json
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def remove(setting_folder, dataset, min_samples):
+    parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    dir = os.path.join(parent_path, dataset, "data")
+    subdir = os.path.join(dir, setting_folder, "sampled_data")
+    files = []
+    if os.path.exists(subdir):
+        files = os.listdir(subdir)
+    if len(files) == 0:
+        subdir = os.path.join(dir, "all_data")
+        files = os.listdir(subdir)
+    files = [f for f in files if f.endswith(".json")]
+
+    for f in files:
+        users = []
+        hierarchies = []
+        num_samples = []
+        user_data = {}
+
+        file_dir = os.path.join(subdir, f)
+        with open(file_dir, "r") as inf:
+            data = json.load(inf)
+
+        num_users = len(data["users"])
+        for i in range(num_users):
+            curr_user = data["users"][i]
+            curr_hierarchy = None
+            if "hierarchies" in data:
+                curr_hierarchy = data["hierarchies"][i]
+            curr_num_samples = data["num_samples"][i]
+            if (curr_num_samples >= min_samples):
+                user_data[curr_user] = data["user_data"][curr_user]
+                users.append(curr_user)
+                if curr_hierarchy is not None:
+                    hierarchies.append(curr_hierarchy)
+                num_samples.append(data["num_samples"][i])
+
+        all_data = {}
+        all_data["users"] = users
+        if len(hierarchies) == len(users):
+            all_data["hierarchies"] = hierarchies
+        all_data["num_samples"] = num_samples
+        all_data["user_data"] = user_data
+
+        file_name = "{}_keep_{}.json".format((f[:-5]), min_samples)
+        ouf_dir = os.path.join(dir, setting_folder, "rem_user_data", file_name)
+
+        logger.info("writing {}".format(file_name))
+        with open(ouf_dir, "w") as outfile:
+            json.dump(all_data, outfile)
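+
+
+# A hedged usage sketch (not part of the library): the folder name below is a
+# hypothetical setting folder produced by BaseDataset.get_setting_folder. This
+# drops users with fewer than 10 samples and writes *_keep_10.json files:
+#
+#   remove(setting_folder="my_setting_folder", dataset="femnist", min_samples=10)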

+ 274 - 0
easyfl/datasets/utils/sample.py

@@ -0,0 +1,274 @@
+"""
+These codes are adopted from LEAF with some modifications.
+
+Samples from all raw data.
+By default, it samples in a non-iid manner; namely, it randomly selects users from
+the raw data until their cumulative amount of data exceeds the given number of
+data points to sample (specified by the --fraction argument).
+The ordering of the original data points is not preserved in the sampled data.
+"""
+
+import json
+import logging
+import os
+import random
+import time
+from collections import OrderedDict
+
+import numpy as np
+
+from easyfl.datasets.simulation import non_iid_class
+from easyfl.datasets.utils.constants import SEED_FILES
+from easyfl.datasets.utils.util import iid_divide
+
+logger = logging.getLogger(__name__)
+
+
+def extreme(data_dir, data_folder, metafile, fraction, num_class=62, num_of_client=100, class_per_client=2, seed=-1):
+    """
+    Note: there are two ways to do an extreme (class-based) split. One is to divide each class into parts and then
+    distribute the parts to clients; the other is to let each client go through the classes and take a part of the
+    data from each. The current version implements the latter; the former is the one adopted for CIFAR-10.
+    If (num_of_client * class_per_client) % num_class == 0, there is no difference (assuming each class has equal
+    size); otherwise, how to handle the remaining parts is open to discussion. Currently, a remaining part is simply
+    given to the next client that comes for collection, which may leave the last clients with more than class_per_client classes.
+    """
+    logger.info("------------------------------")
+    logger.info("sampling data")
+
+    subdir = os.path.join(data_dir, 'all_data')
+    files = os.listdir(subdir)
+    files = [f for f in files if f.endswith('.json')]
+
+    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
+    logger.info("Using seed {}".format(rng_seed))
+    rng = random.Random(rng_seed)
+
+    logger.info(metafile)
+    if metafile is not None:
+        seed_fname = os.path.join(metafile, SEED_FILES['sampling'])
+        with open(seed_fname, 'w+') as f:
+            f.write("# sampling_seed used by sampling script - supply as "
+                    "--smplseed to preprocess.sh or --seed to utils/sample.py\n")
+            f.write(str(rng_seed))
+        logger.info("- random seed written out to {file}".format(file=seed_fname))
+    else:
+        logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
+    new_user_count = 0  # for iid case
+    all_users = []
+    all_user_data = {}
+    for f in files:
+        file_dir = os.path.join(subdir, f)
+        with open(file_dir, 'r') as inf:
+            data = json.load(inf, object_pairs_hook=OrderedDict)
+
+        num_users = len(data['users'])
+
+        tot_num_samples = sum(data['num_samples'])
+        num_new_samples = int(fraction * tot_num_samples)
+
+        raw_list = list(data['user_data'].values())
+        raw_x = [elem['x'] for elem in raw_list]
+        raw_y = [elem['y'] for elem in raw_list]
+        x_list = [item for sublist in raw_x for item in sublist]  # flatten raw_x
+        y_list = [item for sublist in raw_y for item in sublist]  # flatten raw_y
+        num_new_users = num_users
+
+        indices = [i for i in range(tot_num_samples)]
+        new_indices = rng.sample(indices, num_new_samples)
+        users = [str(i + new_user_count) for i in range(num_new_users)]
+        all_users.extend(users)
+        user_data = {}
+        for user in users:
+            user_data[user] = {'x': [], 'y': []}
+        all_x_samples = [x_list[i] for i in new_indices]
+        all_y_samples = [y_list[i] for i in new_indices]
+        x_groups = iid_divide(all_x_samples, num_new_users)
+        y_groups = iid_divide(all_y_samples, num_new_users)
+        for i in range(num_new_users):
+            user_data[users[i]]['x'] = x_groups[i]
+            user_data[users[i]]['y'] = y_groups[i]
+        all_user_data.update(user_data)
+
+        num_samples = [len(user_data[u]['y']) for u in users]
+        new_user_count += num_new_users
+
+    allx = []
+    ally = []
+    for i in all_users:
+        allx.extend(all_user_data[i]['x'])
+        ally.extend(all_user_data[i]['y'])
+    # partition the aggregated sampled data by class; non_iid_class expects numpy arrays and dtypes
+    allx = np.array(allx)
+    ally = np.array(ally)
+    clients, all_user_data = non_iid_class(allx, ally, class_per_client, num_of_client,
+                                           allx.dtype, ally.dtype)
+
+    # ------------
+    # create .json file
+    all_num_samples = []
+    for i in clients:
+        all_num_samples.append(len(all_user_data[i]['y']))
+    all_data = {}
+    all_data['users'] = clients
+    all_data['num_samples'] = all_num_samples
+    all_data['user_data'] = all_user_data
+
+    slabel = ''
+
+    arg_frac = str(fraction)
+    arg_frac = arg_frac[2:]
+    arg_label = arg_frac
+    file_name = '%s_%s_%s.json' % ("class", slabel, arg_label)
+    ouf_dir = os.path.join(data_folder, 'sampled_data', file_name)
+
+    logger.info("writing {}".format(file_name))
+    with open(ouf_dir, 'w') as outfile:
+        json.dump(all_data, outfile)
+
+
+def sample(data_dir, data_folder, metafile, fraction, iid, iid_user_fraction=0.01, seed=-1):
+    logger.info("------------------------------")
+    logger.info("sampling data")
+    subdir = os.path.join(data_dir, 'all_data')
+    files = os.listdir(subdir)
+    files = [f for f in files if f.endswith('.json')]
+
+    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
+    logger.info("Using seed {}".format(rng_seed))
+    rng = random.Random(rng_seed)
+
+    logger.info(metafile)
+    if metafile is not None:
+        seed_fname = os.path.join(metafile, SEED_FILES['sampling'])
+        with open(seed_fname, 'w+') as f:
+            f.write("# sampling_seed used by sampling script - supply as "
+                    "--smplseed to preprocess.sh or --seed to utils/sample.py\n")
+            f.write(str(rng_seed))
+        logger.info("- random seed written out to {file}".format(file=seed_fname))
+    else:
+        logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
+
+    new_user_count = 0  # for iid case
+    for f in files:
+        file_dir = os.path.join(subdir, f)
+        with open(file_dir, 'r') as inf:
+            # Load data into an OrderedDict, to prevent ordering changes
+            # and enable reproducibility
+            data = json.load(inf, object_pairs_hook=OrderedDict)
+
+        num_users = len(data['users'])
+
+        tot_num_samples = sum(data['num_samples'])
+        num_new_samples = int(fraction * tot_num_samples)
+
+        hierarchies = None
+
+        if iid:
+            # iid in femnist puts all the data together and then splits it evenly
+            # among iid_user_fraction * num_users clients
+            raw_list = list(data['user_data'].values())
+            raw_x = [elem['x'] for elem in raw_list]
+            raw_y = [elem['y'] for elem in raw_list]
+            x_list = [item for sublist in raw_x for item in sublist]  # flatten raw_x
+            y_list = [item for sublist in raw_y for item in sublist]  # flatten raw_y
+
+            num_new_users = int(round(iid_user_fraction * num_users))
+            if num_new_users == 0:
+                num_new_users += 1
+
+            indices = [i for i in range(tot_num_samples)]
+            new_indices = rng.sample(indices, num_new_samples)
+            users = ["f%07.0f" % (i + new_user_count) for i in range(num_new_users)]
+
+            user_data = {}
+            for user in users:
+                user_data[user] = {'x': [], 'y': []}
+            all_x_samples = [x_list[i] for i in new_indices]
+            all_y_samples = [y_list[i] for i in new_indices]
+            x_groups = iid_divide(all_x_samples, num_new_users)
+            y_groups = iid_divide(all_y_samples, num_new_users)
+            for i in range(num_new_users):
+                user_data[users[i]]['x'] = x_groups[i]
+                user_data[users[i]]['y'] = y_groups[i]
+
+            num_samples = [len(user_data[u]['y']) for u in users]
+
+            new_user_count += num_new_users
+
+        else:
+            # niid's fraction in femnist chooses clients one by one
+            # until the data size reaches fraction * total data size
+            ctot_num_samples = 0
+
+            users = data['users']
+            users_and_hiers = None
+            if 'hierarchies' in data:
+                users_and_hiers = list(zip(users, data['hierarchies']))
+                rng.shuffle(users_and_hiers)
+            else:
+                rng.shuffle(users)
+            user_i = 0
+            num_samples = []
+            user_data = {}
+
+            if 'hierarchies' in data:
+                hierarchies = []
+
+            while ctot_num_samples < num_new_samples:
+                hierarchy = None
+                if users_and_hiers is not None:
+                    user, hier = users_and_hiers[user_i]
+                else:
+                    user = users[user_i]
+
+                cdata = data['user_data'][user]
+
+                cnum_samples = len(data['user_data'][user]['y'])
+
+                if ctot_num_samples + cnum_samples > num_new_samples:
+                    cnum_samples = num_new_samples - ctot_num_samples
+                    indices = [i for i in range(cnum_samples)]
+                    new_indices = rng.sample(indices, cnum_samples)
+                    x = []
+                    y = []
+                    for i in new_indices:
+                        x.append(data['user_data'][user]['x'][i])
+                        y.append(data['user_data'][user]['y'][i])
+                    cdata = {'x': x, 'y': y}
+
+                if 'hierarchies' in data:
+                    hierarchies.append(hier)
+
+                num_samples.append(cnum_samples)
+                user_data[user] = cdata
+
+                ctot_num_samples += cnum_samples
+                user_i += 1
+
+            if 'hierarchies' in data:
+                users = [u for u, h in users_and_hiers][:user_i]
+            else:
+                users = users[:user_i]
+
+        # ------------
+        # create .json file
+
+        all_data = {}
+        all_data['users'] = users
+        if hierarchies is not None:
+            all_data['hierarchies'] = hierarchies
+        all_data['num_samples'] = num_samples
+        all_data['user_data'] = user_data
+
+        slabel = 'niid'
+        if iid:
+            slabel = 'iid'
+
+        arg_frac = str(fraction)
+        arg_frac = arg_frac[2:]
+        arg_nu = str(iid_user_fraction)
+        arg_nu = arg_nu[2:]
+        arg_label = arg_frac
+        if iid:
+            arg_label = '%s_%s' % (arg_nu, arg_label)
+        file_name = '%s_%s_%s.json' % ((f[:-5]), slabel, arg_label)
+        ouf_dir = os.path.join(data_folder, 'sampled_data', file_name)
+
+        logger.info('writing %s' % file_name)
+        with open(ouf_dir, 'w') as outfile:
+            json.dump(all_data, outfile)

+ 235 - 0
easyfl/datasets/utils/split_data.py

@@ -0,0 +1,235 @@
+"""
+These codes are adopted from LEAF with some modifications.
+
+Splits data into train and test sets.
+"""
+
+import json
+import logging
+import os
+import random
+import sys
+import time
+from collections import OrderedDict
+
+from easyfl.datasets.utils.constants import SEED_FILES
+
+logger = logging.getLogger(__name__)
+
+
+def create_jsons_for(dir, setting_folder, user_files, which_set, max_users, include_hierarchy, subdir, arg_label):
+    """Used in split-by-user case"""
+    user_count = 0
+    json_index = 0
+    users = []
+    if include_hierarchy:
+        hierarchies = []
+    else:
+        hierarchies = None
+    num_samples = []
+    user_data = {}
+    for (i, t) in enumerate(user_files):
+        if include_hierarchy:
+            (u, h, ns, f) = t
+        else:
+            (u, ns, f) = t
+
+        file_dir = os.path.join(subdir, f)
+        with open(file_dir, 'r') as inf:
+            data = json.load(inf)
+
+        users.append(u)
+        if include_hierarchy:
+            hierarchies.append(h)
+        num_samples.append(ns)
+        user_data[u] = data['user_data'][u]
+        user_count += 1
+
+        if (user_count == max_users) or (i == len(user_files) - 1):
+
+            all_data = {}
+            all_data['users'] = users
+            if include_hierarchy:
+                all_data['hierarchies'] = hierarchies
+            all_data['num_samples'] = num_samples
+            all_data['user_data'] = user_data
+
+            data_i = f.find('data')
+            num_i = data_i + 5
+            num_to_end = f[num_i:]
+            param_i = num_to_end.find('_')
+            param_to_end = '.json'
+            if param_i != -1:
+                param_to_end = num_to_end[param_i:]
+            nf = "{}_{}{}".format(f[:(num_i - 1)], json_index, param_to_end)
+            file_name = '{}_{}_{}.json'.format((nf[:-5]), which_set, arg_label)
+            ouf_dir = os.path.join(dir, setting_folder, which_set, file_name)
+
+            logger.info('writing {}'.format(file_name))
+            with open(ouf_dir, 'w') as outfile:
+                json.dump(all_data, outfile)
+
+            user_count = 0
+            json_index += 1
+            users = []
+            if include_hierarchy:
+                hierarchies = []
+            num_samples = []
+            user_data = {}
+
+
+def split_train_test(setting_folder, metafile, name, user, frac, seed):
+    logger.info("------------------------------")
+    logger.info("generating training and test sets")
+
+    parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+    dir = os.path.join(parent_path, name, 'data')
+    subdir = os.path.join(dir, setting_folder, 'rem_user_data')
+    files = []
+    if os.path.exists(subdir):
+        files = os.listdir(subdir)
+    if len(files) == 0:
+        subdir = os.path.join(dir, setting_folder, 'sampled_data')
+        if os.path.exists(subdir):
+            files = os.listdir(subdir)
+    if len(files) == 0:
+        subdir = os.path.join(dir, 'all_data')
+        files = os.listdir(subdir)
+    files = [f for f in files if f.endswith('.json')]
+
+    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
+    rng = random.Random(rng_seed)
+    if metafile is not None:
+        seed_fname = os.path.join(metafile, SEED_FILES['split'])
+        with open(seed_fname, 'w+') as f:
+            f.write("# split_seed used by sampling script - supply as "
+                    "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
+            f.write(str(rng_seed))
+        logger.info("- random seed written out to {file}".format(file=seed_fname))
+    else:
+        logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
+
+    arg_label = str(frac)
+    arg_label = arg_label[2:]
+
+    # check if data contains information on hierarchies
+    file_dir = os.path.join(subdir, files[0])
+    with open(file_dir, 'r') as inf:
+        data = json.load(inf)
+    include_hierarchy = 'hierarchies' in data
+
+    if (user):
+        logger.info("splitting data by user")
+
+        # 1 pass through all the json files to instantiate arr
+        # containing all possible (user, .json file name) tuples
+        user_files = []
+        for f in files:
+            file_dir = os.path.join(subdir, f)
+            with open(file_dir, 'r') as inf:
+                # Load data into an OrderedDict, to prevent ordering changes
+                # and enable reproducibility
+                data = json.load(inf, object_pairs_hook=OrderedDict)
+            if include_hierarchy:
+                user_files.extend([(u, h, ns, f) for (u, h, ns) in
+                                   zip(data['users'], data['hierarchies'], data['num_samples'])])
+            else:
+                user_files.extend([(u, ns, f) for (u, ns) in
+                                   zip(data['users'], data['num_samples'])])
+
+        # randomly sample from user_files to pick training set users
+        num_users = len(user_files)
+        num_train_users = int(frac * num_users)
+        indices = [i for i in range(num_users)]
+        train_indices = rng.sample(indices, num_train_users)
+        train_blist = [False for i in range(num_users)]
+        for i in train_indices:
+            train_blist[i] = True
+        train_user_files = []
+        test_user_files = []
+        for i in range(num_users):
+            if (train_blist[i]):
+                train_user_files.append(user_files[i])
+            else:
+                test_user_files.append(user_files[i])
+
+        max_users = sys.maxsize
+        if name == 'femnist':
+            max_users = 50  # max number of users per json file
+        create_jsons_for(dir, setting_folder, train_user_files, 'train', max_users, include_hierarchy, subdir,
+                         arg_label)
+        create_jsons_for(dir, setting_folder, test_user_files, 'test', max_users, include_hierarchy, subdir, arg_label)
+
+    else:
+        logger.info("splitting data by sample")
+
+        for f in files:
+            file_dir = os.path.join(subdir, f)
+            with open(file_dir, 'r') as inf:
+                # Load data into an OrderedDict, to prevent ordering changes
+                # and enable reproducibility
+                data = json.load(inf, object_pairs_hook=OrderedDict)
+
+            num_samples_train = []
+            user_data_train = {}
+            num_samples_test = []
+            user_data_test = {}
+
+            user_indices = []  # indices of users in data['users'] that are not deleted
+
+            for i, u in enumerate(data['users']):
+                user_data_train[u] = {'x': [], 'y': []}
+                user_data_test[u] = {'x': [], 'y': []}
+
+                curr_num_samples = len(data['user_data'][u]['y'])
+                if curr_num_samples >= 2:
+                    user_indices.append(i)
+
+                    # ensures number of train and test samples both >= 1
+                    num_train_samples = max(1, int(frac * curr_num_samples))
+                    if curr_num_samples == 2:
+                        num_train_samples = 1
+
+                    num_test_samples = curr_num_samples - num_train_samples
+                    num_samples_train.append(num_train_samples)
+                    num_samples_test.append(num_test_samples)
+
+                    indices = [j for j in range(curr_num_samples)]
+                    train_indices = rng.sample(indices, num_train_samples)
+                    train_blist = [False for _ in range(curr_num_samples)]
+                    for j in train_indices:
+                        train_blist[j] = True
+
+                    for j in range(curr_num_samples):
+                        if (train_blist[j]):
+                            user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
+                            user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
+                        else:
+                            user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
+                            user_data_test[u]['y'].append(data['user_data'][u]['y'][j])
+
+            users = [data['users'][i] for i in user_indices]
+
+            all_data_train = {}
+            all_data_train['users'] = users
+            all_data_train['num_samples'] = num_samples_train
+            all_data_train['user_data'] = user_data_train
+            all_data_test = {}
+            all_data_test['users'] = users
+            all_data_test['num_samples'] = num_samples_test
+            all_data_test['user_data'] = user_data_test
+
+            if include_hierarchy:
+                all_data_train['hierarchies'] = data['hierarchies']
+                all_data_test['hierarchies'] = data['hierarchies']
+
+            file_name_train = '{}_train_{}.json'.format((f[:-5]), arg_label)
+            file_name_test = '{}_test_{}.json'.format((f[:-5]), arg_label)
+            ouf_dir_train = os.path.join(dir, setting_folder, 'train', file_name_train)
+            ouf_dir_test = os.path.join(dir, setting_folder, 'test', file_name_test)
+            logger.info("writing {}".format(file_name_train))
+            with open(ouf_dir_train, 'w') as outfile:
+                json.dump(all_data_train, outfile)
+            logger.info("writing {}".format(file_name_test))
+            with open(ouf_dir_test, 'w') as outfile:
+                json.dump(all_data_test, outfile)

+ 42 - 0
easyfl/datasets/utils/util.py

@@ -0,0 +1,42 @@
+import pickle
+
+
+def save_obj(obj, name):
+    with open(name + '.pkl', 'wb') as f:
+        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
+
+
+def load_obj(name):
+    with open(name + '.pkl', 'rb') as f:
+        return pickle.load(f)
+
+
+def save_dict(dic, filename):
+    with open(filename, 'wb') as f:
+        pickle.dump(dic, f)
+
+
+def load_dict(filename):
+    with open(filename, 'rb') as f:
+        dic = pickle.load(f)
+    return dic
+
+
+def iid_divide(l, g):
+    """
+    divide list l among g groups
+    each group has either int(len(l)/g) or int(len(l)/g)+1 elements
+    returns a list of groups
+    """
+    num_elems = len(l)
+    group_size = int(len(l) / g)
+    num_big_groups = num_elems - g * group_size
+    num_small_groups = g - num_big_groups
+    glist = []
+    for i in range(num_small_groups):
+        glist.append(l[group_size * i: group_size * (i + 1)])
+    bi = group_size * num_small_groups
+    group_size += 1
+    for i in range(num_big_groups):
+        glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
+    return glist
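+
+
+if __name__ == "__main__":
+    # A minimal sketch (not part of the library): dividing 10 items into 3 groups
+    # yields two groups of 3 and one group of 4.
+    print(iid_divide(list(range(10)), 3))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]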

+ 18 - 0
easyfl/distributed/__init__.py

@@ -0,0 +1,18 @@
+from easyfl.distributed.distributed import (
+    dist_init,
+    get_device,
+    grouping,
+    reduce_models,
+    reduce_models_only_params,
+    reduce_value,
+    reduce_values,
+    reduce_weighted_values,
+    gather_value,
+    CPU
+)
+
+from easyfl.distributed.slurm import setup, get_ip
+
+__all__ = ['dist_init', 'get_device', 'grouping', 'gather_value', 'setup', 'get_ip',
+           'reduce_models', 'reduce_models_only_params', 'reduce_value', 'reduce_values', 'reduce_weighted_values']
+

+ 257 - 0
easyfl/distributed/distributed.py

@@ -0,0 +1,257 @@
+import logging
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+logger = logging.getLogger(__name__)
+
+CPU = "cpu"
+
+RANDOMIZE_GROUPING = "random"
+GREEDY_GROUPING = "greedy"
+SLOWEST_GROUPING = "slowest"
+
+
+def reduce_models(model, sample_sum):
+    """Aggregate models across devices and update the model with the new aggregated model parameters.
+
+    Args:
+        model (nn.Module): The model in a device to aggregate.
+        sample_sum (int): Sum of the total dataset sizes of clients in a device.
+    """
+    dist.all_reduce(sample_sum, op=dist.ReduceOp.SUM)
+    state = model.state_dict()
+    for k in state.keys():
+        dist.all_reduce(state[k], op=dist.ReduceOp.SUM)
+        state[k] = torch.div(state[k], sample_sum)
+    model.load_state_dict(state)
+
+
+def reduce_models_only_params(model, sample_sum):
+    """Aggregate models across devices and update the model with the new aggregated model parameters,
+    excluding the persistent buffers like BN stats.
+
+    Args:
+        model (nn.Module): The model in a device to aggregate.
+        sample_sum (torch.Tensor): Sum of the total dataset sizes of clients in a device.
+    """
+    dist.all_reduce(sample_sum, op=dist.ReduceOp.SUM)
+    for param in model.parameters():
+        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
+        param.data = torch.div(param.data, sample_sum)
+
+
+def reduce_value(value, device):
+    """Calculate the sum of the value across devices.
+
+    Args:
+        value (float/int): Value to sum.
+        device (str): The device where the value is on, either cpu or cuda devices.
+    Returns:
+         torch.Tensor: Sum of the values.
+    """
+    v = torch.tensor(value).to(device)
+    dist.all_reduce(v, op=dist.ReduceOp.SUM)
+    return v
+
+
+def reduce_values(values, device):
+    """Calculate the average of values across devices.
+
+    Args:
+        values (list[float|int]): Values to average.
+        device (str): The device where the value is on, either cpu or cuda devices.
+    Returns:
+         torch.Tensor: The average of the values across devices.
+    """
+    length = torch.tensor(len(values)).to(device)
+    total = torch.tensor(sum(values)).to(device)
+    dist.all_reduce(length, op=dist.ReduceOp.SUM)
+    dist.all_reduce(total, op=dist.ReduceOp.SUM)
+    return torch.div(total, length)
+
+
+def reduce_weighted_values(values, weights, device):
+    """Calculate the weighted average of values across devices.
+
+    Args:
+        values (list[float|int]): Values to average.
+        weights (list[float|int]): The weights to calculate weighted average.
+        device (str): The device where the value is on, either cpu or cuda devices.
+    Returns:
+         torch.Tensor: The average of values across devices.
+    """
+    values = torch.tensor(values).to(device)
+    weights = torch.tensor(weights).to(device)
+    total_weights = torch.sum(weights).to(device)
+    weighted_sum = torch.sum(values * weights).to(device)
+    dist.all_reduce(total_weights, op=dist.ReduceOp.SUM)
+    dist.all_reduce(weighted_sum, op=dist.ReduceOp.SUM)
+    return torch.div(weighted_sum, total_weights)
+
+
+def gather_value(value, world_size, device):
+    """Gather the value from devices to a list.
+
+    Args:
+        value (float|int): The value to gather.
+        world_size (int): The number of processes.
+        device (str): The device where the value is on, either cpu or cuda devices.
+    Returns:
+         list[torch.Tensor]: A list of gathered values.
+    """
+    v = torch.tensor(value).to(device)
+    target = [v.clone() for _ in range(world_size)]
+    dist.all_gather(target, v)
+    return target
+
+
+def grouping(clients, world_size, default_time=10, strategy=RANDOMIZE_GROUPING, seed=1):
+    """Divide clients into groups with different strategies.
+
+    Args:
+        clients (list[:obj:`BaseClient`]): A list of clients.
+        world_size (int): The number of processes; it represents the number of groups here.
+        default_time (float, optional): The default training time for clients that have not been profiled.
+        strategy (str, optional): Strategy of grouping, options: random, greedy, slowest.
+            When no recognized strategy is given, each client forms its own group.
+        seed (int, optional): Random seed.
+
+    Returns:
+        list[list[:obj:`BaseClient`]]: Groups of clients, each group is a sub-list.
+    """
+    np.random.seed(seed)
+    if strategy == RANDOMIZE_GROUPING:
+        return randomize_grouping(clients, world_size)
+    elif strategy == GREEDY_GROUPING:
+        return greedy_grouping(clients, world_size, default_time)
+    elif strategy == SLOWEST_GROUPING:
+        return slowest_grouping(clients, world_size)
+    else:
+        # default, no strategy applied
+        return [[client] for client in clients]
+
+
+def randomize_grouping(clients, world_size):
+    """"Randomly divide clients into groups.
+
+    Args:
+        clients (list[:obj:`BaseClient`]): A list of clients.
+        world_size (int): The number of processes, it represent the number of groups here.
+
+    Returns:
+        list[list[:obj:`BaseClient`]]: Groups of clients, each group is a sub-list.
+    """
+    num_of_clients = len(clients)
+    np.random.shuffle(clients)
+    data_per_client = num_of_clients // world_size
+    large_group_num = num_of_clients - world_size * data_per_client
+    small_group_num = world_size - large_group_num
+    grouped_clients = []
+    for i in range(small_group_num):
+        base_index = data_per_client * i
+        grouped_clients.append(clients[base_index: base_index + data_per_client])
+    small_size = data_per_client * small_group_num
+    data_per_client += 1
+    for i in range(large_group_num):
+        base_index = small_size + data_per_client * i
+        grouped_clients.append(clients[base_index: base_index + data_per_client])
+    return grouped_clients
+
+
+def greedy_grouping(clients, world_size, default_time):
+    """"Greedily allocate the clients with longest training time to the most available device.
+
+
+    Args:
+        clients (list[:obj:`BaseClient`]): A list of clients.
+        world_size (int): The number of processes, it represent the number of groups here.
+        default_time (float, optional): The default training time for not profiled clients.
+
+    Returns:
+        list[list[:obj:`BaseClient`]]: Groups of clients, each group is a sub-list.
+    """
+    round_time_estimation = [[i, c.round_time] if c.round_time != 0
+                             else [i, default_time] for i, c in enumerate(clients)]
+    round_time_estimation = sorted(round_time_estimation, reverse=True, key=lambda tup: (tup[1], tup[0]))
+    top_world_size = round_time_estimation[:world_size]
+    groups = [[clients[index]] for (index, time) in top_world_size]
+    time_sum = [time for (index, time) in top_world_size]
+    for i in round_time_estimation[world_size:]:
+        min_index = np.argmin(time_sum)
+        groups[min_index].append(clients[i[0]])
+        time_sum[min_index] += i[1]
+    return groups
+
+
+def slowest_grouping(clients, world_size):
+    """"Allocate the clients with longest training time to the most busy device.
+    Only for experiment, not practical in use.
+
+
+    Args:
+        clients (list[:obj:`BaseClient`]): A list of clients.
+        world_size (int): The number of processes, it represent the number of groups here.
+
+    Returns:
+        list[list[:obj:`BaseClient`]]: Groups of clients, each group is a sub-list.
+    """
+    num_of_clients = len(clients)
+    clients = sorted(clients, key=lambda tup: (tup.round_time, tup.cid))
+    data_per_client = num_of_clients // world_size
+    large_group_num = num_of_clients - world_size * data_per_client
+    small_group_num = world_size - large_group_num
+    grouped_clients = []
+    for i in range(small_group_num):
+        base_index = data_per_client * i
+        grouped_clients.append(clients[base_index: base_index + data_per_client])
+    small_size = data_per_client * small_group_num
+    data_per_client += 1
+    for i in range(large_group_num):
+        base_index = small_size + data_per_client * i
+        grouped_clients.append(clients[base_index: base_index + data_per_client])
+    return grouped_clients
+
+
+def dist_init(backend, init_method, world_size, rank, local_rank):
+    """Initialize PyTorch distribute.
+
+    Args:
+        backend (str or Backend): Distributed backend to use, e.g., `nccl`, `gloo`.
+        init_method (str, optional): URL specifying how to initialize the process group.
+        world_size (int, optional): Number of processes participating in the job.
+        rank (int, optional): Rank of the current process.
+        local_rank (int, optional): Local rank of the current process.
+
+    Returns:
+        int: Rank of current process.
+        int: Total number of processes.
+    """
+    dist.init_process_group(backend, init_method=init_method, rank=rank, world_size=world_size)
+    assert dist.is_initialized()
+    return rank, world_size
+
+
+def get_device(gpu, world_size, local_rank):
+    """Obtain the device by checking the number of GPUs and distributed settings.
+
+    Args:
+        gpu (int): The number of requested GPUs.
+        world_size (int): The number of processes.
+        local_rank (int): The local rank of the current process.
+
+    Returns:
+        str: Device to be used in PyTorch like `tensor.to(device)`.
+    """
+    if gpu > world_size:
+        logger.error("Available gpu: {}, requested gpu: {}".format(world_size, gpu))
+        raise ValueError("available number of gpu are less than requested")
+
+    # TODO: think of a better way to handle this, maybe just use one config param instead of two.
+    assert gpu == world_size
+
+    n = torch.cuda.device_count()
+
+    device_ids = list(range(n))
+    return device_ids[local_rank]
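+
+
+if __name__ == "__main__":
+    # A hedged sketch (not part of the library): grouping() only relies on the
+    # `round_time` and `cid` attributes of a client, so a stub object is enough
+    # to illustrate the greedy strategy.
+    class _StubClient:
+        def __init__(self, cid, round_time):
+            self.cid = cid
+            self.round_time = round_time
+
+    stubs = [_StubClient("f%07.0f" % i, round_time=i + 1) for i in range(5)]
+    groups = grouping(stubs, world_size=2, strategy=GREEDY_GROUPING)
+    print([[c.cid for c in g] for g in groups])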

+ 64 - 0
easyfl/distributed/slurm.py

@@ -0,0 +1,64 @@
+import logging
+import os
+import re
+import socket
+
+logger = logging.getLogger(__name__)
+
+
+def setup(port=23344):
+    """Setup distributed settings of slurm.
+
+    Args:
+        port (int, optional): The port of the primary server.
+            It auto-increments by 1 while the port is in use.
+
+    Returns:
+        int: The rank of current process.
+        int: The local rank of current process.
+        int: Total number of processes.
+        str: The address of the distributed init method.
+    """
+    try:
+        rank = int(os.environ['SLURM_PROCID'])
+        local_rank = int(os.environ['SLURM_LOCALID'])
+        world_size = int(os.environ['SLURM_NTASKS'])
+        host = get_ip(os.environ['SLURM_STEP_NODELIST'])
+        while is_port_in_use(host, port):
+            port += 1
+        host_addr = 'tcp://' + host + ':' + str(port)
+    except KeyError:
+        return 0, 0, 0, ""
+    return rank, local_rank, world_size, host_addr
+
+
+def get_ip(node_list):
+    """Get the ip address of nodes.
+
+    Args:
+        node_list (str): Name of the nodes.
+
+    Returns:
+        str: The first node in the nodes.
+    """
+    if "[" not in node_list:
+        return node_list
+    r = re.search(r'([\w-]*)\[(\d*)[-+,+\d]*\]', node_list)
+    if not r:
+        return
+    base, node = r.groups()
+    return base + node
+
+
+def is_port_in_use(host, port):
+    """Check whether the port is in use.
+
+    Args:
+        host (str): Host address.
+        port (int): Port to use.
+
+    Returns:
+        bool: A flag to indicate whether the port is in use in the host.
+    """
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex((host, port)) == 0
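Under SLURM, setup() can feed dist_init directly; a sketch assuming the job was launched with srun so the SLURM_* environment variables are present (the import path for dist_init is assumed, as in the sketch above):

from easyfl.distributed import slurm
from easyfl.distributed import dist_init  # assumed import path

rank, local_rank, world_size, host_addr = slurm.setup(port=23344)
if host_addr:  # setup() returns an empty address when the SLURM variables are missing
    rank, world_size = dist_init("nccl", host_addr, world_size, rank, local_rank)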

+ 0 - 0
easyfl/encryption/__init__.py


+ 1 - 0
easyfl/models/__init__.py

@@ -0,0 +1 @@
+from easyfl.models.model import BaseModel

+ 36 - 0
easyfl/models/lenet.py

@@ -0,0 +1,36 @@
+from torch import nn
+import torch.nn.functional as F
+
+from easyfl.models.model import BaseModel
+
+
+class Model(BaseModel):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 5, padding=(2, 2))
+        self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
+        self.fc1 = nn.Linear(7 * 7 * 64, 2048)
+        self.fc2 = nn.Linear(2048, 62)
+        self.init_weights()
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = F.relu(self.conv2(x))
+        x = F.max_pool2d(x, 2, 2)
+        x = x.view(-1, 7 * 7 * 64)
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+
+        return x
+
+    def init_weights(self):
+        init_range = 0.1
+        self.conv1.weight.data.uniform_(-init_range, init_range)
+        self.conv1.bias.data.zero_()
+        self.conv2.weight.data.uniform_(-init_range, init_range)
+        self.conv2.bias.data.zero_()
+        self.fc1.weight.data.uniform_(-init_range, init_range)
+        self.fc1.bias.data.zero_()
+        self.fc2.weight.data.uniform_(-init_range, init_range)
+        self.fc2.bias.data.zero_()
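A quick shape check for this model; it assumes 28x28 single-channel inputs, since only then does the 7 * 7 * 64 flatten line up with the 62-way output layer:

import torch
from easyfl.models.lenet import Model

model = Model()
x = torch.randn(4, 1, 28, 28)  # batch of 4 grayscale 28x28 images
logits = model(x)
print(logits.shape)            # torch.Size([4, 62])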

+ 24 - 0
easyfl/models/model.py

@@ -0,0 +1,24 @@
+import importlib
+import logging
+from os import path
+
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+
+class BaseModel(nn.Module):
+    def __init__(self):
+        super(BaseModel, self).__init__()
+
+
+def load_model(model_name: str):
+    dir_path = path.dirname(path.realpath(__file__))
+    model_file = path.join(dir_path, "{}.py".format(model_name))
+    if not path.exists(model_file):
+        logger.error("Please specify a valid model; {} does not exist.".format(model_file))
+        raise ValueError("Model '{}' is not found under easyfl/models.".format(model_name))
+    model_path = "easyfl.models.{}".format(model_name)
+    model_lib = importlib.import_module(model_path)
+    model = getattr(model_lib, "Model")
+    # TODO: maybe return the model class initiator
+    return model
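load_model resolves easyfl/models/<name>.py and returns its Model class rather than an instance, so callers instantiate it themselves; a minimal sketch:

from easyfl.models.model import load_model

model_cls = load_model("lenet")  # imports easyfl.models.lenet and grabs its Model class
model = model_cls()              # instantiate separately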

+ 124 - 0
easyfl/models/resnet.py

@@ -0,0 +1,124 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion *
+                               planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(nn.Module):
+    """ResNet
+    Note two main differences from official pytorch version:
+    1. conv1 kernel size: pytorch version uses kernel_size=7
+    2. average pooling: pytorch version uses AdaptiveAvgPool
+    """
+
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(ResNet, self).__init__()
+        self.in_planes = 64
+        self.feature_dim = 512 * block.expansion
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.avgpool = nn.AvgPool2d((4, 4))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = self.avgpool(out)
+        out = out.view(out.size(0), -1)
+        out = self.fc(out)
+        return out
+
+
+def ResNet18(num_classes=10):
+    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
+
+
+def ResNet34(num_classes=10):
+    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
+
+
+def ResNet50(num_classes=10):
+    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
+
+
+def ResNet101(num_classes=10):
+    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
+
+
+def ResNet152(num_classes=10):
+    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
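As the class docstring notes, the 3x3 conv1 and fixed 4x4 average pool target 32x32 inputs (e.g., CIFAR); a quick sanity check of ResNet18:

import torch
from easyfl.models.resnet import ResNet18

net = ResNet18(num_classes=10)
out = net(torch.randn(2, 3, 32, 32))  # 32 -> 32 -> 16 -> 8 -> 4 through the four stages
print(out.shape)                      # torch.Size([2, 10])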

+ 102 - 0
easyfl/models/resnet18.py

@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from easyfl.models import BaseModel
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion *
+                               planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Model(BaseModel):
+    """ResNet18 model
+    Note two main differences from official pytorch version:
+    1. conv1 kernel size: pytorch version uses kernel_size=7
+    2. average pooling: pytorch version uses AdaptiveAvgPool
+    """
+
+    def __init__(self, block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=10):
+        super(Model, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
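Because this variant exposes the network as a class named Model, it can also be resolved by file name through load_model; a sketch:

import torch
from easyfl.models.model import load_model

model = load_model("resnet18")()        # load_model returns the class; instantiate it
out = model(torch.randn(2, 3, 32, 32))  # CIFAR-sized RGB input
print(out.shape)                        # torch.Size([2, 10])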

+ 107 - 0
easyfl/models/resnet50.py

@@ -0,0 +1,107 @@
+'''ResNet in PyTorch.
+For Pre-activation ResNet, see 'preact_resnet.py'.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Deep Residual Learning for Image Recognition. arXiv:1512.03385
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from easyfl.models import BaseModel
+
+
+class BasicBlock(BaseModel):
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Bottleneck(BaseModel):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                               stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, self.expansion *
+                               planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
+
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, self.expansion * planes,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes)
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = F.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+        out += self.shortcut(x)
+        out = F.relu(out)
+        return out
+
+
+class Model(BaseModel):
+    def __init__(self, block=Bottleneck, num_blocks=[3, 4, 6, 3], num_classes=10):
+        super(Model, self).__init__()
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
+                               stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def ResNet50():
+    return Model(Bottleneck, [3, 4, 6, 3])

+ 36 - 0
easyfl/models/rnn.py

@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+from easyfl.models.model import BaseModel
+
+
+def repackage_hidden(h):
+    """Wraps hidden states in new Tensors, to detach them from their history."""
+
+    if isinstance(h, torch.Tensor):
+        return h.detach()
+    else:
+        return tuple(repackage_hidden(v) for v in h)
+
+
+class Model(BaseModel):
+    def __init__(self, embedding_dim=8, voc_size=80, lstm_unit=256, batch_first=True, n_layers=2):
+        super(Model, self).__init__()
+        self.encoder = nn.Embedding(voc_size, embedding_dim)
+        self.lstm = nn.LSTM(embedding_dim, lstm_unit, n_layers, batch_first=batch_first)
+        self.decoder = nn.Linear(lstm_unit, voc_size)
+        self.init_weights()
+
+    def forward(self, inp):
+        inp = self.encoder(inp)
+        inp, _ = self.lstm(inp)
+        # extract the last state of output for prediction
+        hidden = inp[:, -1]
+        output = self.decoder(hidden)
+        return output
+
+    def init_weights(self):
+        init_range = 0.1
+        self.encoder.weight.data.uniform_(-init_range, init_range)
+        self.decoder.bias.data.zero_()
+        self.decoder.weight.data.uniform_(-init_range, init_range)
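The model consumes integer token indices (batch-first by default) and predicts the next token from the last LSTM step; a sketch with the default 80-symbol vocabulary and an arbitrary sequence length:

import torch
from easyfl.models.rnn import Model

model = Model()                         # embedding_dim=8, voc_size=80, lstm_unit=256
tokens = torch.randint(0, 80, (4, 50))  # 4 sequences of 50 character indices
logits = model(tokens)
print(logits.shape)                     # torch.Size([4, 80]) -- one score per vocabulary symbol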

+ 40 - 0
easyfl/models/simple_cnn.py

@@ -0,0 +1,40 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from easyfl.models import BaseModel
+
+
+class Model(BaseModel):
+    def __init__(self, channels=32):
+        super(Model, self).__init__()
+        self.num_channels = channels
+        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1)
+        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels * 2, 3, stride=1)
+        self.conv3 = nn.Conv2d(self.num_channels * 2, self.num_channels * 2, 3, stride=1)
+
+        # 2 fully connected layers to transform the output of the convolution layers to the final output
+        self.fc1 = nn.Linear(4 * 4 * self.num_channels * 2, self.num_channels * 2)
+        self.fc2 = nn.Linear(self.num_channels * 2, 10)
+
+    def forward(self, s):
+        # Shape comments below assume 3 x 32 x 32 inputs (e.g., CIFAR).
+        s = self.conv1(s)  # batch_size x num_channels x 30 x 30
+
+        s = F.relu(F.max_pool2d(s, 2))  # batch_size x num_channels x 15 x 15
+
+        s = self.conv2(s)  # batch_size x num_channels*2 x 13 x 13
+
+        s = F.relu(F.max_pool2d(s, 2))  # batch_size x num_channels*2 x 6 x 6
+
+        s = self.conv3(s)  # batch_size x num_channels*2 x 4 x 4
+
+        # s = F.relu(F.max_pool2d(s, 2))  # would shrink to 2 x 2 and break the flatten below
+
+        # flatten the output of each image
+        s = s.view(-1, 4 * 4 * self.num_channels * 2)  # batch_size x 4*4*num_channels*2
+
+        # apply the 2 fully connected layers
+        s = F.relu(self.fc1(s))
+        s = self.fc2(s)  # batch_size x 10
+
+        return s
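With 3x32x32 inputs the shapes above work out so that the flatten matches fc1 exactly; a quick check with the default channel width:

import torch
from easyfl.models.simple_cnn import Model

model = Model()                         # channels=32
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)                        # torch.Size([2, 10])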

+ 65 - 0
easyfl/models/vgg9.py

@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+import math
+from easyfl.models import BaseModel
+
+cfg = {
+    'VGG9': [32, 64, 'M', 128, 128, 'M', 256, 256, 'M'],
+}
+
+
+def make_layers(cfg, batch_norm):
+    layers = []
+    in_channels = 3
+    for v in cfg:
+        if v == 'M':
+            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
+        else:
+            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
+            if batch_norm:
+                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
+
+            else:
+                layers += [conv2d, nn.ReLU(inplace=True)]
+            in_channels = v
+    return nn.Sequential(*layers)
+
+
+class Model(BaseModel):
+    def __init__(self, features=make_layers(cfg['VGG9'], batch_norm=False), num_classes=10):
+        super(Model, self).__init__()
+        self.features = features
+        self.classifier = nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(4096, 512),
+            nn.ReLU(True),
+            nn.Dropout(p=0.1),
+            nn.Linear(512, 512),
+            nn.ReLU(True),
+            nn.Linear(512, num_classes),
+        )
+        self._initialize_weights()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.reset_parameters()
+            elif isinstance(m, nn.Linear):
+                m.weight.data.normal_(0, 0.01)
+                m.bias.data.zero_()
+
+
+def VGG9(batch_norm=False, **kwargs):
+    model = Model(make_layers(cfg['VGG9'], batch_norm), **kwargs)
+    return model
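The classifier's first Linear layer expects 4096 features, which is 256 channels on a 4x4 map after the three pooling stages of a 32x32 input; the batch-norm variant is constructed the same way:

import torch
from easyfl.models.vgg9 import VGG9

net = VGG9(batch_norm=True)             # same layer config with BatchNorm after each conv
out = net(torch.randn(2, 3, 32, 32))
print(out.shape)                        # torch.Size([2, 10])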

+ 0 - 0
easyfl/pb/__init__.py


+ 75 - 0
easyfl/pb/client_service_pb2.py

@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# source: easyfl/pb/client_service.proto
+"""Generated protocol buffer code."""
+from google.protobuf.internal import enum_type_wrapper
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+from easyfl.pb import common_pb2 as easyfl_dot_pb_dot_common__pb2
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1e\x65\x61syfl/pb/client_service.proto\x12\teasyfl.pb\x1a\x16\x65\x61syfl/pb/common.proto\"\x85\x01\n\x0eOperateRequest\x12&\n\x04type\x18\x01 \x01(\x0e\x32\x18.easyfl.pb.OperationType\x12\r\n\x05model\x18\x02 \x01(\x0c\x12\x12\n\ndata_index\x18\x03 \x01(\x05\x12(\n\x06\x63onfig\x18\x04 \x01(\x0b\x32\x18.easyfl.pb.OperateConfig\"\xce\x01\n\rOperateConfig\x12\x12\n\nbatch_size\x18\x01 \x01(\x05\x12\x13\n\x0blocal_epoch\x18\x02 \x01(\x05\x12\x0c\n\x04seed\x18\x03 \x01(\x03\x12\'\n\toptimizer\x18\x04 \x01(\x0b\x32\x14.easyfl.pb.Optimizer\x12\x12\n\nlocal_test\x18\x05 \x01(\x08\x12\x0f\n\x07task_id\x18\x06 \x01(\t\x12\x10\n\x08round_id\x18\x07 \x01(\x05\x12\r\n\x05track\x18\x08 \x01(\x08\x12\x17\n\x0ftest_batch_size\x18\t \x01(\x05\"7\n\tOptimizer\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\n\n\x02lr\x18\x02 \x01(\x02\x12\x10\n\x08momentum\x18\x03 \x01(\x02\"4\n\x0fOperateResponse\x12!\n\x06status\x18\x01 \x01(\x0b\x32\x11.easyfl.pb.Status*4\n\rOperationType\x12\x11\n\rOP_TYPE_TRAIN\x10\x00\x12\x10\n\x0cOP_TYPE_TEST\x10\x01\x32S\n\rClientService\x12\x42\n\x07Operate\x12\x19.easyfl.pb.OperateRequest\x1a\x1a.easyfl.pb.OperateResponse\"\x00\x62\x06proto3')
+
+_OPERATIONTYPE = DESCRIPTOR.enum_types_by_name['OperationType']
+OperationType = enum_type_wrapper.EnumTypeWrapper(_OPERATIONTYPE)
+OP_TYPE_TRAIN = 0
+OP_TYPE_TEST = 1
+
+
+_OPERATEREQUEST = DESCRIPTOR.message_types_by_name['OperateRequest']
+_OPERATECONFIG = DESCRIPTOR.message_types_by_name['OperateConfig']
+_OPTIMIZER = DESCRIPTOR.message_types_by_name['Optimizer']
+_OPERATERESPONSE = DESCRIPTOR.message_types_by_name['OperateResponse']
+OperateRequest = _reflection.GeneratedProtocolMessageType('OperateRequest', (_message.Message,), {
+  'DESCRIPTOR' : _OPERATEREQUEST,
+  '__module__' : 'easyfl.pb.client_service_pb2'
+  # @@protoc_insertion_point(class_scope:easyfl.pb.OperateRequest)
+  })
+_sym_db.RegisterMessage(OperateRequest)
+
+OperateConfig = _reflection.GeneratedProtocolMessageType('OperateConfig', (_message.Message,), {
+  'DESCRIPTOR' : _OPERATECONFIG,
+  '__module__' : 'easyfl.pb.client_service_pb2'
+  # @@protoc_insertion_point(class_scope:easyfl.pb.OperateConfig)
+  })
+_sym_db.RegisterMessage(OperateConfig)
+
+Optimizer = _reflection.GeneratedProtocolMessageType('Optimizer', (_message.Message,), {
+  'DESCRIPTOR' : _OPTIMIZER,
+  '__module__' : 'easyfl.pb.client_service_pb2'
+  # @@protoc_insertion_point(class_scope:easyfl.pb.Optimizer)
+  })
+_sym_db.RegisterMessage(Optimizer)
+
+OperateResponse = _reflection.GeneratedProtocolMessageType('OperateResponse', (_message.Message,), {
+  'DESCRIPTOR' : _OPERATERESPONSE,
+  '__module__' : 'easyfl.pb.client_service_pb2'
+  # @@protoc_insertion_point(class_scope:easyfl.pb.OperateResponse)
+  })
+_sym_db.RegisterMessage(OperateResponse)
+
+_CLIENTSERVICE = DESCRIPTOR.services_by_name['ClientService']
+if _descriptor._USE_C_DESCRIPTORS == False:
+
+  DESCRIPTOR._options = None
+  _OPERATIONTYPE._serialized_start=525
+  _OPERATIONTYPE._serialized_end=577
+  _OPERATEREQUEST._serialized_start=70
+  _OPERATEREQUEST._serialized_end=203
+  _OPERATECONFIG._serialized_start=206
+  _OPERATECONFIG._serialized_end=412
+  _OPTIMIZER._serialized_start=414
+  _OPTIMIZER._serialized_end=469
+  _OPERATERESPONSE._serialized_start=471
+  _OPERATERESPONSE._serialized_end=523
+  _CLIENTSERVICE._serialized_start=579
+  _CLIENTSERVICE._serialized_end=662
+# @@protoc_insertion_point(module_scope)

+ 66 - 0
easyfl/pb/client_service_pb2_grpc.py

@@ -0,0 +1,66 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
+from easyfl.pb import client_service_pb2 as easyfl_dot_pb_dot_client__service__pb2
+
+
+class ClientServiceStub(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def __init__(self, channel):
+        """Constructor.
+
+        Args:
+            channel: A grpc.Channel.
+        """
+        self.Operate = channel.unary_unary(
+                '/easyfl.pb.ClientService/Operate',
+                request_serializer=easyfl_dot_pb_dot_client__service__pb2.OperateRequest.SerializeToString,
+                response_deserializer=easyfl_dot_pb_dot_client__service__pb2.OperateResponse.FromString,
+                )
+
+
+class ClientServiceServicer(object):
+    """Missing associated documentation comment in .proto file."""
+
+    def Operate(self, request, context):
+        """Missing associated documentation comment in .proto file."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details('Method not implemented!')
+        raise NotImplementedError('Method not implemented!')
+
+
+def add_ClientServiceServicer_to_server(servicer, server):
+    rpc_method_handlers = {
+            'Operate': grpc.unary_unary_rpc_method_handler(
+                    servicer.Operate,
+                    request_deserializer=easyfl_dot_pb_dot_client__service__pb2.OperateRequest.FromString,
+                    response_serializer=easyfl_dot_pb_dot_client__service__pb2.OperateResponse.SerializeToString,
+            ),
+    }
+    generic_handler = grpc.method_handlers_generic_handler(
+            'easyfl.pb.ClientService', rpc_method_handlers)
+    server.add_generic_rpc_handlers((generic_handler,))
+
+
+ # This class is part of an EXPERIMENTAL API.
+class ClientService(object):
+    """Missing associated documentation comment in .proto file."""
+
+    @staticmethod
+    def Operate(request,
+            target,
+            options=(),
+            channel_credentials=None,
+            call_credentials=None,
+            insecure=False,
+            compression=None,
+            wait_for_ready=None,
+            timeout=None,
+            metadata=None):
+        return grpc.experimental.unary_unary(request, target, '/easyfl.pb.ClientService/Operate',
+            easyfl_dot_pb_dot_client__service__pb2.OperateRequest.SerializeToString,
+            easyfl_dot_pb_dot_client__service__pb2.OperateResponse.FromString,
+            options, channel_credentials,
+            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
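A minimal client-side sketch of calling the Operate RPC through the generated stub; the address and config values below are illustrative only and assume a client service is already listening there:

import grpc

from easyfl.pb import client_service_pb2 as client_pb
from easyfl.pb import client_service_pb2_grpc as client_grpc

channel = grpc.insecure_channel("localhost:23400")  # illustrative address
stub = client_grpc.ClientServiceStub(channel)

request = client_pb.OperateRequest(
    type=client_pb.OP_TYPE_TRAIN,
    config=client_pb.OperateConfig(batch_size=32, local_epoch=1, task_id="demo", round_id=0),
)
response = stub.Operate(request)
print(response.status)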

File diff suppressed because it is too large to display
+ 17 - 0
easyfl/pb/common_pb2.py


Some files were not shown because too many files changed in this diff