86 lines
2.3 KiB
Python
Executable File
86 lines
2.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import argparse
|
|
import time
|
|
import multiprocessing
|
|
|
|
import paho.mqtt.client as mqtt
|
|
|
|
|
|
def consumer(barrier, conn, host, expected_messages, timeout):
    """Count messages on the ``benchmark`` topic and report the receive rate.

    Connects an MQTT client to ``host``, subscribes to ``benchmark`` and, once
    the subscription is acknowledged, waits on ``barrier``. It then polls once
    per second until ``expected_messages`` messages have been counted, or
    until at least two messages have arrived and ``timeout`` seconds have
    passed since the first one.

    Args:
        barrier: Object whose ``wait()`` is called from the subscribe callback.
        conn: Object whose ``send()`` receives the computed rate.
        host: Broker address passed to ``client.connect``.
        expected_messages: Message count at which polling stops.
        timeout: Seconds after the first message at which polling stops.

    The value sent on ``conn`` is ``(messages - 1) / (last - first)`` in
    messages per second, or ``0.0`` when fewer than two messages were received
    or no time elapsed between the first and last message.
    """
    pid = multiprocessing.current_process().pid
    client = mqtt.Client(f"consumer_{pid}")
    # client.username_pw_set("username", "password")

    total_messages = 0
    first_message_time = None
    last_message_time = None

    def on_message(client, userdata, msg):
        nonlocal total_messages, first_message_time, last_message_time
        now = time.time()
        # Record the timestamps before bumping the counter so the polling
        # loop never sees a count without a matching first timestamp.
        if first_message_time is None:
            first_message_time = now
        last_message_time = now
        total_messages += 1

    def on_subscribe(client, userdata, mid, granted_qos):
        # Signal readiness only after the subscription has been acknowledged.
        barrier.wait()

    client.on_message = on_message
    client.on_subscribe = on_subscribe

    client.connect(host)
    client.loop_start()
    client.subscribe("benchmark")

    while total_messages < expected_messages:
        if first_message_time is not None and total_messages >= 2:
            if time.time() - first_message_time >= timeout:
                break
        time.sleep(1)
    client.loop_stop()

    # With fewer than two messages (or identical timestamps) there is no
    # interval to divide by. Report 0.0 instead of raising, so the parent's
    # recv() always gets a value.
    if total_messages < 2 or last_message_time <= first_message_time:
        rate = 0.0
    else:
        rate = (total_messages - 1) / (last_message_time - first_message_time)
    conn.send(rate)
|
|
|
|
|
|
def main():
    """Run the MQTT benchmark and print the mean consumer receive rate.

    Parses the command line, connects a producer client to the broker, starts
    ``--consumers`` processes running ``consumer``, waits on a shared barrier
    with them, and then publishes a ``--size``-character payload to the
    ``benchmark`` topic for as long as any consumer process is alive. The
    rates reported by the consumers are averaged and printed with one decimal.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("host")
    parser.add_argument("--consumers", type=int, default=1)
    parser.add_argument("--messages", type=int, default=1000)
    parser.add_argument("--size", type=int, default=1)
    parser.add_argument("--timeout", type=int, default=10)

    args = parser.parse_args()
    if args.consumers < 1:
        # Zero consumers would leave nothing to average over.
        parser.error("--consumers must be at least 1")

    client = mqtt.Client("producer")
    client.connect(args.host)

    # One barrier slot per consumer plus one for this producer.
    barrier = multiprocessing.Barrier(args.consumers + 1)
    receiver, sender = multiprocessing.Pipe(False)
    # The loop variables below are deliberately not named ``consumer``:
    # rebinding that name would hide the target function, which start methods
    # that pickle the target look up by name.
    workers = [
        multiprocessing.Process(
            target=consumer,
            args=(barrier, sender, args.host, args.messages, args.timeout),
        )
        for _ in range(args.consumers)
    ]

    message = "0" * args.size

    for worker in workers:
        worker.start()

    # wait for all processes to connect
    barrier.wait()

    # fire messages
    while any(worker.is_alive() for worker in workers):
        client.publish("benchmark", message)

    # wait for consumers
    for worker in workers:
        worker.join()

    # collect results: read only what was actually sent, so a consumer that
    # died without reporting cannot block this process forever.
    rates = []
    while receiver.poll():
        rates.append(receiver.recv())
    if not rates:
        raise SystemExit("no consumer reported a rate")

    rate = sum(rates) / len(rates)
    print(f"{rate:.1f}")


# Guard the entry point so child processes that re-import this module do not
# start the benchmark again.
if __name__ == "__main__":
    main()
|