1
0
epfl_cs451/barrier.py

89 lines
2.1 KiB
Python
Raw Normal View History

2020-09-14 08:57:26 +02:00
#!/usr/bin/env python3
import argparse
import socket
2020-09-20 15:55:05 +02:00
class Barrier:
    """TCP barrier that waits until a fixed number of peers have connected.

    The barrier listens on ``host:port``, accepts connections until
    ``wait_for`` of them are open at the same time, and then closes them all.
    Presumably peers block on their connection and treat its closure as the
    release signal -- confirm against the client implementation.
    """

    def __init__(self, host, port, wait_for):
        """Store the listening endpoint and the number of peers to wait for.

        Args:
            host: Address the listening socket is bound to.
            port: TCP port the listening socket is bound to.
            wait_for: Number of connections to accept before releasing.
        """
        self.host = host
        self.port = port
        self.wait_for = wait_for

    def listen(self):
        """Create the listening socket, bind it and start listening.

        The socket is stored in ``self.sock``.

        Raises:
            OSError: If configuring, binding or listening fails. The socket
                is closed before the error propagates.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # Allow quick restarts on the same port.
            self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.sock.bind((self.host, self.port))
            self.sock.listen()
        except OSError:
            # Do not leak the descriptor when the endpoint is unusable.
            self.sock.close()
            raise

    def waitSingle(self):
        """Accept peers one at a time, yielding each peer's address.

        ``listen()`` must have been called first, unless ``wait_for`` is not
        positive, in which case nothing is accepted and the generator ends
        immediately.

        Yields:
            The address returned by ``accept()`` for each new connection.

        All accepted connections are closed when the generator finishes,
        whether because ``wait_for`` peers arrived, because the generator
        was closed early, or because ``accept()`` raised.
        """
        connections = []
        try:
            # Check the count *before* blocking in accept() so that a
            # non-positive wait_for terminates instead of waiting forever.
            while len(connections) < self.wait_for:
                conn, addr = self.sock.accept()
                connections.append(conn)
                yield addr
        finally:
            for conn in connections:
                conn.close()

    def wait(self):
        """Block until all peers have connected and return their addresses.

        Returns:
            A list with the address of every accepted peer, in arrival order.
        """
        return list(self.waitSingle())
2020-09-14 08:57:26 +02:00
if __name__ == "__main__":
    # Command-line entry point: start a barrier and report each peer as it
    # connects, until the requested number of processes has arrived.
    cli = argparse.ArgumentParser()

    cli.add_argument(
        "--host",
        dest="host",
        default="0.0.0.0",
        help="IP address where the barrier listens to (default: any)",
    )
    cli.add_argument(
        "--port",
        dest="port",
        type=int,
        default=11000,
        help="TCP port where the barrier listens to (default: 11000)",
    )
    cli.add_argument(
        "--processes",
        dest="processes",
        type=int,
        required=True,
        help="Number of processes the barrier waits for",
    )

    args = cli.parse_args()

    barrier = Barrier(args.host, args.port, args.processes)
    barrier.listen()
    print("Barrier listens on {}:{} and waits for {} processes".format(args.host, args.port, args.processes))

    # Consume the generator lazily so every connection is reported as soon
    # as it is accepted rather than after the whole barrier completes.
    for peer_addr in barrier.waitSingle():
        print("Connection from {}".format(peer_addr))