12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import logging
- import os
- import re
- import socket
- logger = logging.getLogger(__name__)
def setup(port=23344):
    """Set up the distributed settings from the SLURM environment.

    Reads the task layout that SLURM exports through environment variables
    and builds the TCP address of the distributed init method from the first
    node of the job step.

    Args:
        port (int, optional): The preferred port of the primary server.
            It is incremented by 1 for as long as ``is_port_in_use`` reports
            the candidate port as in use on the primary host.

    Returns:
        tuple: ``(rank, local_rank, world_size, host_addr)`` where

        * rank (int): The rank of current process (``SLURM_PROCID``).
        * local_rank (int): The local rank of current process
          (``SLURM_LOCALID``).
        * world_size (int): Total number of processes (``SLURM_NTASKS``).
        * host_addr (str): The address of the distributed init method,
          formatted as ``tcp://<host>:<port>``.

        When any of the required SLURM variables is missing, the fallback
        ``(0, 0, 0, "")`` is returned instead.

    Raises:
        ValueError: If a SLURM rank/size variable is not an integer, or if
            no host name can be derived from ``SLURM_STEP_NODELIST``.
    """
    # Only the environment lookups are guarded: a missing variable means the
    # process was not launched by SLURM, which is reported via the fallback.
    try:
        rank = int(os.environ['SLURM_PROCID'])
        local_rank = int(os.environ['SLURM_LOCALID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_STEP_NODELIST']
    except KeyError:
        return 0, 0, 0, ""

    host = get_ip(node_list)
    if not host:
        # Fail with a clear message instead of an obscure error raised from
        # inside the socket layer when the host is None or empty.
        raise ValueError(
            'Cannot derive the primary host from SLURM_STEP_NODELIST: '
            '{!r}'.format(node_list))

    # NOTE(review): every process runs this probe independently; if the
    # primary server is already listening when a later rank gets here, that
    # rank would presumably move on to another port. Confirm against the
    # launcher that all ranks end up with the same address.
    while is_port_in_use(host, port):
        port += 1

    host_addr = 'tcp://{}:{}'.format(host, port)
    return rank, local_rank, world_size, host_addr
def _first_hostlist_entry(node_list):
    """Return the part of ``node_list`` before its first top-level comma."""
    depth = 0
    for index, char in enumerate(node_list):
        if char == "[":
            depth += 1
        elif char == "]":
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            # Commas inside brackets separate ranges, not entries.
            return node_list[:index]
    return node_list


def get_ip(node_list):
    """Get the name of the first node in a SLURM node list.

    Despite its name, this function performs no address resolution; it only
    extracts a host name from the compact node-list notation, e.g.
    ``"node[01-03,07]"`` gives ``"node01"`` and ``"gpu1,gpu[3-5]"`` gives
    ``"gpu1"``.

    Args:
        node_list (str): Name of the nodes, either a plain host name, a
            comma-separated list, or a bracketed range expression.

    Returns:
        str or None: The first node in the nodes, or ``None`` when the first
        entry contains brackets that cannot be parsed.
    """
    # Look at the first entry only, so that a bracket expression further
    # along the list is never mistaken for the first node.
    first_entry = _first_hostlist_entry(node_list)
    if "[" not in first_entry:
        return first_entry

    match = re.search(r'([\w-]*)\[(\d*)[-+,+\d]*\]', first_entry)
    if not match:
        return None
    base, node = match.groups()
    return base + node
def is_port_in_use(host, port, timeout=3.0):
    """Check whether the port is in use.

    The check attempts a TCP connection to ``(host, port)``; the port counts
    as in use only when that connection succeeds.

    Args:
        host (str): Host address.
        port (int): Port to use.
        timeout (float or None, optional): Seconds to wait for the
            connection attempt. ``None`` waits without a limit, as plain
            blocking sockets do.

    Returns:
        bool: A flag to indicate whether the port is in use in the host.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Without a timeout, a host that silently drops packets would stall
        # the probe for the operating system's full connect timeout.
        s.settimeout(timeout)
        # connect_ex returns an error code instead of raising on connection
        # failures; 0 means something accepted the connection.
        return s.connect_ex((host, port)) == 0
|