slurm.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import logging
  2. import os
  3. import re
  4. import socket
  5. logger = logging.getLogger(__name__)
  6. def setup(port=23344):
  7. """Setup distributed settings of slurm.
  8. Args:
  9. port (int, optional): The port of the primary server.
  10. It respectively auto-increments by 1 when the port is in-use.
  11. Returns:
  12. int: The rank of current process.
  13. int: The local rank of current process.
  14. int: Total number of processes.
  15. str: The address of the distributed init method.
  16. """
  17. try:
  18. rank = int(os.environ['SLURM_PROCID'])
  19. local_rank = int(os.environ['SLURM_LOCALID'])
  20. world_size = int(os.environ['SLURM_NTASKS'])
  21. host = get_ip(os.environ['SLURM_STEP_NODELIST'])
  22. while is_port_in_use(host, port):
  23. port += 1
  24. host_addr = 'tcp://' + host + ':' + str(port)
  25. except KeyError:
  26. return 0, 0, 0, ""
  27. return rank, local_rank, world_size, host_addr
  28. def get_ip(node_list):
  29. """Get the ip address of nodes.
  30. Args:
  31. node_list (str): Name of the nodes.
  32. Returns:
  33. str: The first node in the nodes.
  34. """
  35. if "[" not in node_list:
  36. return node_list
  37. r = re.search(r'([\w-]*)\[(\d*)[-+,+\d]*\]', node_list)
  38. if not r:
  39. return
  40. base, node = r.groups()
  41. return base + node
  42. def is_port_in_use(host, port):
  43. """Check whether the port is in use.
  44. Args:
  45. host (str): Host address.
  46. port (int): Port to use.
  47. Returns:
  48. bool: A flag to indicate whether the port is in use in the host.
  49. """
  50. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  51. return s.connect_ex((host, port)) == 0