Initial commit, getting all the stuff from PlatformIO

This commit is contained in:
2025-11-02 17:55:41 +00:00
commit 4b4b816a8c
3003 changed files with 1213319 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
#pragma once
#include <Arduino.h>
#if defined(ESP8266)
#if (ARDUINO_ESP8266_MAJOR != 3) || (ARDUINO_ESP8266_MINOR < 1)
#error PicoMQTT requires ESP8266 board core version >= 3.1
#endif
#elif defined(ESP32)
#if ESP_ARDUINO_VERSION < ESP_ARDUINO_VERSION_VAL(2, 0, 7)
#error PicoMQTT requires ESP32 board core version >= 2.0.7
#endif
#endif
#include "PicoMQTT/client.h"
#include "PicoMQTT/server.h"

View File

@@ -0,0 +1,21 @@
#pragma once
namespace PicoMQTT {
class AutoId {
public:
typedef unsigned int Id;
AutoId(): id(generate_id()) {}
AutoId(const AutoId &) = default;
const Id id;
private:
static Id generate_id() {
static Id next_id = 1;
return next_id++;
}
};
}

View File

@@ -0,0 +1,290 @@
#include "client.h"
#include "debug.h"
namespace PicoMQTT {
BasicClient::BasicClient(::Client & client, unsigned long keep_alive_millis,
unsigned long socket_timeout_millis)
: Connection(client, keep_alive_millis, socket_timeout_millis) {
TRACE_FUNCTION
}
bool BasicClient::connect(
const char * host,
uint16_t port,
const char * id,
const char * user,
const char * pass,
const char * will_topic,
const char * will_message,
const size_t will_message_length,
uint8_t will_qos,
bool will_retain,
const bool clean_session,
ConnectReturnCode * connect_return_code) {
TRACE_FUNCTION
if (connect_return_code) {
*connect_return_code = CRC_UNDEFINED;
}
client.stop();
if (!client.connect(host, port)) {
return false;
}
message_id_generator.reset();
const bool will = will_topic && will_message;
const uint8_t connect_flags =
(user ? 1 : 0) << 7
| (user && pass ? 1 : 0) << 6
| (will && will_retain ? 1 : 0) << 5
| (will && will_qos ? 1 : 0) << 3
| (will ? 1 : 0) << 2
| (clean_session ? 1 : 0) << 1;
const size_t client_id_length = strlen(id);
const size_t will_topic_length = (will && will_topic) ? strlen(will_topic) : 0;
const size_t user_length = user ? strlen(user) : 0;
const size_t pass_length = pass ? strlen(pass) : 0;
const size_t total_size = 6 // protocol name
+ 1 // protocol level
+ 1 // connect flags
+ 2 // keep-alive
+ client_id_length + 2
+ (will ? will_topic_length + 2 : 0)
+ (will ? will_message_length + 2 : 0)
+ (user ? user_length + 2 : 0)
+ (user && pass ? pass_length + 2 : 0);
auto packet = build_packet(Packet::CONNECT, 0, total_size);
packet.write_string("MQTT", 4);
packet.write_u8(4);
packet.write_u8(connect_flags);
packet.write_u16(keep_alive_millis / 1000);
packet.write_string(id, client_id_length);
if (will) {
packet.write_string(will_topic, will_topic_length);
packet.write_string(will_message, will_message_length);
}
if (user) {
packet.write_string(user, user_length);
if (pass) {
packet.write_string(pass, pass_length);
}
}
if (!packet.send()) {
return false;
}
wait_for_reply(Packet::CONNACK, [this, connect_return_code](IncomingPacket & packet) {
TRACE_FUNCTION
if (packet.size != 2) {
on_protocol_violation();
return;
}
/* const uint8_t connect_ack_flags = */ packet.read_u8();
const uint8_t crc = packet.read_u8();
if (connect_return_code) {
*connect_return_code = (ConnectReturnCode) crc;
}
if (crc != 0) {
// connection refused
client.stop();
}
});
return client.connected();
}
void BasicClient::loop() {
TRACE_FUNCTION
if (client.connected() && get_millis_since_last_write() >= keep_alive_millis) {
// ping time!
build_packet(Packet::PINGREQ).send();
wait_for_reply(Packet::PINGRESP, [](IncomingPacket &) {});
}
Connection::loop();
}
Publisher::Publish BasicClient::begin_publish(const char * topic, const size_t payload_size,
uint8_t qos, bool retain, uint16_t message_id) {
TRACE_FUNCTION
return Publish(
*this,
client.connected() ? client : PrintMux(),
topic, payload_size,
(qos >= 1) ? 1 : 0,
retain,
message_id, // dup if message_id is non-zero
message_id ? message_id : message_id_generator.generate() // generate only if message_id == 0
);
}
bool BasicClient::on_publish_complete(const Publish & publish) {
TRACE_FUNCTION
if (publish.qos == 0) {
return true;
}
bool confirmed = false;
wait_for_reply(Packet::PUBACK, [&publish, &confirmed](IncomingPacket & puback) {
confirmed |= (puback.read_u16() == publish.message_id);
});
return confirmed;
}
bool BasicClient::subscribe(const String & topic, uint8_t qos, uint8_t * qos_granted) {
TRACE_FUNCTION
if (qos > 1) {
return false;
}
const size_t topic_size = topic.length();
const uint16_t message_id = message_id_generator.generate();
auto packet = build_packet(Packet::SUBSCRIBE, 0b0010, 2 + 2 + topic_size + 1);
packet.write_u16(message_id);
packet.write_string(topic.c_str(), topic_size);
packet.write_u8(qos);
packet.send();
uint8_t code = 0x80;
wait_for_reply(Packet::SUBACK, [this, message_id, &code](IncomingPacket & packet) {
if (packet.read_u16() != message_id) {
on_protocol_violation();
} else {
code = packet.read_u8();
}
});
if (code == 0x80) {
return false;
}
if (qos_granted) {
*qos_granted = code;
}
return client.connected();
}
bool BasicClient::unsubscribe(const String & topic) {
TRACE_FUNCTION
const size_t topic_size = topic.length();
const uint16_t message_id = message_id_generator.generate();
auto packet = build_packet(Packet::UNSUBSCRIBE, 0b0010, 2 + 2 + topic_size);
packet.write_u16(message_id);
packet.write_string(topic.c_str(), topic_size);
packet.send();
wait_for_reply(Packet::UNSUBACK, [this, message_id](IncomingPacket & packet) {
if (packet.read_u16() != message_id) {
on_protocol_violation();
}
});
return client.connected();
}
Client::Client(ClientSocketInterface * socket,
const char * host, uint16_t port, const char * id, const char * user, const char * password,
unsigned long reconnect_interval_millis, unsigned long keep_alive_millis, unsigned long socket_timeout_millis)
: SocketOwner<std::unique_ptr<ClientSocketInterface>>(socket),
BasicClient(this->socket->get_client(), keep_alive_millis, socket_timeout_millis),
host(host), port(port), client_id(id), username(user), password(password),
will({"", "", 0, false}),
reconnect_interval_millis(reconnect_interval_millis),
last_reconnect_attempt(millis() - reconnect_interval_millis) {
TRACE_FUNCTION
}
Client::SubscriptionId Client::subscribe(const String & topic_filter, MessageCallback callback) {
TRACE_FUNCTION
const auto ret = SubscribedMessageListener::subscribe(topic_filter, callback);
BasicClient::subscribe(topic_filter);
return ret;
}
void Client::unsubscribe(const String & topic_filter) {
TRACE_FUNCTION
BasicClient::unsubscribe(topic_filter);
SubscribedMessageListener::unsubscribe(topic_filter);
}
void Client::on_message(const char * topic, IncomingPacket & packet) {
SubscribedMessageListener::fire_message_callbacks(topic, packet);
}
void Client::loop() {
TRACE_FUNCTION
if (!client.connected()) {
if (host.isEmpty() || !port) {
return;
}
if (millis() - last_reconnect_attempt < reconnect_interval_millis) {
return;
}
const bool connection_established = connect(host.c_str(), port,
client_id.isEmpty() ? "" : client_id.c_str(),
username.isEmpty() ? nullptr : username.c_str(),
password.isEmpty() ? nullptr : password.c_str(),
will.topic.isEmpty() ? nullptr : will.topic.c_str(),
will.payload.isEmpty() ? nullptr : will.payload.c_str(),
will.payload.isEmpty() ? 0 : will.payload.length(),
will.qos, will.retain);
last_reconnect_attempt = millis();
if (!connection_established) {
if (connection_failure_callback) {
connection_failure_callback();
}
return;
}
for (const auto & kv : subscriptions) {
BasicClient::subscribe(kv.first.c_str());
}
on_connect();
}
BasicClient::loop();
}
void Client::on_connect() {
TRACE_FUNCTION
BasicClient::on_connect();
if (connected_callback) {
connected_callback();
}
}
void Client::on_disconnect() {
TRACE_FUNCTION
BasicClient::on_disconnect();
if (disconnected_callback) {
disconnected_callback();
}
}
}

View File

@@ -0,0 +1,121 @@
#pragma once
#include <Arduino.h>
#include "connection.h"
#include "incoming_packet.h"
#include "outgoing_packet.h"
#include "pico_interface.h"
#include "publisher.h"
#include "subscriber.h"
#include "utils.h"
namespace PicoMQTT {
class BasicClient: public PicoMQTTInterface, public Connection, public Publisher {
public:
BasicClient(::Client & client, unsigned long keep_alive_millis = 60 * 1000,
unsigned long socket_timeout_millis = 10 * 1000);
bool connect(
const char * host, uint16_t port = 1883,
const char * id = "", const char * user = nullptr, const char * pass = nullptr,
const char * will_topic = nullptr, const char * will_message = nullptr,
const size_t will_message_length = 0, uint8_t willQos = 0, bool willRetain = false,
const bool cleanSession = true,
ConnectReturnCode * connect_return_code = nullptr);
using Publisher::begin_publish;
virtual Publish begin_publish(const char * topic, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override;
bool subscribe(const String & topic, uint8_t qos = 0, uint8_t * qos_granted = nullptr);
bool unsubscribe(const String & topic);
void loop() override;
virtual void on_connect() {}
private:
virtual bool on_publish_complete(const Publish & publish) override;
};
class ClientSocketInterface {
public:
virtual ::Client & get_client() = 0;
virtual ~ClientSocketInterface() {}
};
class ClientSocketProxy: public ClientSocketInterface {
public:
ClientSocketProxy(::Client & client): client(client) {}
virtual ::Client & get_client() override { return client; }
::Client & client;
};
template <typename ClientType>
class ClientSocket: public ClientType, public ClientSocketInterface {
public:
using ClientType::ClientType;
virtual ::Client & get_client() override { return *this; }
};
class Client: public SocketOwner<std::unique_ptr<ClientSocketInterface>>, public BasicClient,
public SubscribedMessageListener {
public:
Client(const char * host = nullptr, uint16_t port = 1883, const char * id = nullptr, const char * user = nullptr,
const char * password = nullptr, unsigned long reconnect_interval_millis = 5 * 1000,
unsigned long keep_alive_millis = 60 * 1000, unsigned long socket_timeout_millis = 10 * 1000)
: Client(new ClientSocket<::WiFiClient>(), host, port, id, user, password, reconnect_interval_millis, keep_alive_millis,
socket_timeout_millis) {
}
template <typename ClientType>
Client(ClientType & client, const char * host = nullptr, uint16_t port = 1883, const char * id = nullptr,
const char * user = nullptr, const char * password = nullptr,
unsigned long reconnect_interval_millis = 5 * 1000,
unsigned long keep_alive_millis = 60 * 1000, unsigned long socket_timeout_millis = 10 * 1000)
: Client(new ClientSocketProxy(client), host, port, id, user, password, reconnect_interval_millis, keep_alive_millis,
socket_timeout_millis) {
}
using SubscribedMessageListener::subscribe;
virtual SubscriptionId subscribe(const String & topic_filter, MessageCallback callback) override;
virtual void unsubscribe(const String & topic_filter) override;
virtual void loop() override;
String host;
uint16_t port;
String client_id;
String username;
String password;
struct {
String topic;
String payload;
uint8_t qos;
bool retain;
} will;
unsigned long reconnect_interval_millis;
std::function<void()> connected_callback;
std::function<void()> disconnected_callback;
std::function<void()> connection_failure_callback;
virtual void on_connect() override;
virtual void on_disconnect() override;
protected:
Client(ClientSocketInterface * client,
const char * host, uint16_t port, const char * id, const char * user, const char * password,
unsigned long reconnect_interval_millis, unsigned long keep_alive_millis, unsigned long socket_timeout_millis);
unsigned long last_reconnect_attempt;
virtual void on_message(const char * topic, IncomingPacket & packet) override;
};
}

View File

@@ -0,0 +1,163 @@
#include "Arduino.h"
#include "client_wrapper.h"
#include "debug.h"
namespace PicoMQTT {
ClientWrapper::ClientWrapper(::Client & client, unsigned long socket_timeout_millis):
socket_timeout_millis(socket_timeout_millis), client(client) {
TRACE_FUNCTION
}
// reads
int ClientWrapper::available_wait(unsigned long timeout) {
TRACE_FUNCTION
const unsigned long start_millis = millis();
while (true) {
const int ret = available();
if (ret > 0) {
return ret;
}
if (!connected()) {
// A disconnected client might still have unread data waiting in buffers. Don't move this check earlier.
return 0;
}
const unsigned long elapsed = millis() - start_millis;
if (elapsed > timeout) {
return 0;
}
yield();
}
}
int ClientWrapper::read(uint8_t * buf, size_t size) {
TRACE_FUNCTION
const unsigned long start_millis = millis();
size_t ret = 0;
while (ret < size) {
const unsigned long now_millis = millis();
const unsigned long elapsed_millis = now_millis - start_millis;
if (elapsed_millis > socket_timeout_millis) {
// timeout
abort();
break;
}
const unsigned long remaining_millis = socket_timeout_millis - elapsed_millis;
const int available_size = available_wait(remaining_millis);
if (available_size <= 0) {
// timeout
abort();
break;
}
const int chunk_size = size - ret < (size_t) available_size ? size - ret : (size_t) available_size;
const int bytes_read = client.read(buf + ret, chunk_size);
if (bytes_read <= 0) {
// connection error
abort();
break;
}
ret += bytes_read;
}
return ret;
}
int ClientWrapper::read() {
TRACE_FUNCTION
if (!available_wait(socket_timeout_millis)) {
return -1;
}
return client.read();
}
int ClientWrapper::peek() {
TRACE_FUNCTION
if (!available_wait(socket_timeout_millis)) {
return -1;
}
return client.peek();
}
// writes
size_t ClientWrapper::write(const uint8_t * buffer, size_t size) {
TRACE_FUNCTION
size_t ret = 0;
while (connected() && ret < size) {
const int bytes_written = client.write(buffer + ret, size - ret);
if (bytes_written <= 0) {
// connection error
abort();
return 0;
}
ret += bytes_written;
}
return ret;
}
size_t ClientWrapper::write(uint8_t value) {
TRACE_FUNCTION
return write(&value, 1);
}
// simple wrappers forwarding requests to this->client
int ClientWrapper::connect(IPAddress ip, uint16_t port) {
TRACE_FUNCTION
return client.connect(ip, port);
}
int ClientWrapper::connect(const char * host, uint16_t port) {
TRACE_FUNCTION
return client.connect(host, port);
}
#ifdef PICOMQTT_EXTRA_CONNECT_METHODS
int ClientWrapper::connect(IPAddress ip, uint16_t port, int32_t timeout) {
TRACE_FUNCTION
return client.connect(ip, port, timeout);
}
int ClientWrapper::connect(const char * host, uint16_t port, int32_t timeout) {
TRACE_FUNCTION
return client.connect(host, port, timeout);
}
#endif
int ClientWrapper::available() {
TRACE_FUNCTION
return client.available();
}
void ClientWrapper::flush() {
TRACE_FUNCTION
client.flush();
}
void ClientWrapper::stop() {
TRACE_FUNCTION
client.stop();
}
uint8_t ClientWrapper::connected() {
TRACE_FUNCTION
return client.connected();
}
ClientWrapper::operator bool() {
return bool(client);
}
}

View File

@@ -0,0 +1,46 @@
#pragma once
#include <WiFiClient.h>
#include "config.h"
namespace PicoMQTT {
class ClientWrapper: public ::Client {
public:
ClientWrapper(::Client & client, unsigned long socket_timeout_millis);
ClientWrapper(const ClientWrapper &) = default;
virtual int peek() override;
virtual int read() override;
virtual int read(uint8_t * buf, size_t size) override;
virtual size_t write(const uint8_t * buffer, size_t size) override;
virtual size_t write(uint8_t value) override final;
// all of the below call the corresponding method on this->client
virtual int connect(IPAddress ip, uint16_t port) override;
virtual int connect(const char * host, uint16_t port) override;
#ifdef PICOMQTT_EXTRA_CONNECT_METHODS
virtual int connect(IPAddress ip, uint16_t port, int32_t timeout) override;
virtual int connect(const char * host, uint16_t port, int32_t timeout) override;
#endif
virtual int available() override;
virtual void flush() override;
virtual void stop() override;
virtual uint8_t connected() override;
virtual operator bool() override;
const unsigned long socket_timeout_millis;
void abort() {
// TODO: Use client.abort() if client is a WiFiClient on ESP8266?
stop();
}
protected:
::Client & client;
int available_wait(unsigned long timeout);
};
}

View File

@@ -0,0 +1,37 @@
#pragma once
#include <Arduino.h>
#ifndef PICOMQTT_MAX_TOPIC_SIZE
#define PICOMQTT_MAX_TOPIC_SIZE 256
#endif
#ifndef PICOMQTT_MAX_MESSAGE_SIZE
#define PICOMQTT_MAX_MESSAGE_SIZE 1024
#endif
#ifndef PICOMQTT_MAX_CLIENT_ID_SIZE
/*
* The MQTT standard requires brokers to accept client ids that are
* 1-23 chars long, but allows longer client IDs to be accepted too.
*/
#define PICOMQTT_MAX_CLIENT_ID_SIZE 64
#endif
#ifndef PICOMQTT_MAX_USERPASS_SIZE
#define PICOMQTT_MAX_USERPASS_SIZE 256
#endif
#ifndef PICOMQTT_OUTGOING_BUFFER_SIZE
#define PICOMQTT_OUTGOING_BUFFER_SIZE 128
#endif
#ifdef ESP32
// Uncomment this define to make PicoMQTT compatible with framework variants
// which have extra Client::connect methods which accept a timeout parameter.
// #define PICOMQTT_EXTRA_CONNECT_METHODS
#endif
// #define PICOMQTT_DEBUG
// #define PICOMQTT_DEBUG_TRACE_FUNCTIONS

View File

@@ -0,0 +1,168 @@
#include "config.h"
#include "connection.h"
#include "debug.h"
namespace PicoMQTT {
Connection::Connection(::Client & client, unsigned long keep_alive_millis, unsigned long socket_timeout_millis) :
client(client, socket_timeout_millis),
keep_alive_millis(keep_alive_millis),
last_read(millis()), last_write(millis()) {
TRACE_FUNCTION
}
OutgoingPacket Connection::build_packet(Packet::Type type, uint8_t flags, size_t length) {
TRACE_FUNCTION
last_write = millis();
auto ret = OutgoingPacket(client, type, flags, length);
ret.write_header();
return ret;
}
void Connection::on_timeout() {
TRACE_FUNCTION
client.abort();
on_disconnect();
}
void Connection::on_protocol_violation() {
TRACE_FUNCTION
on_disconnect();
}
void Connection::on_disconnect() {
TRACE_FUNCTION
client.stop();
}
void Connection::disconnect() {
TRACE_FUNCTION
build_packet(Packet::DISCONNECT).send();
client.stop();
}
bool Connection::connected() {
TRACE_FUNCTION
return client.connected();
}
void Connection::wait_for_reply(Packet::Type type, std::function<void(IncomingPacket & packet)> handler) {
TRACE_FUNCTION
const unsigned long start = millis();
while (client.connected() && (millis() - start < client.socket_timeout_millis)) {
IncomingPacket packet(client);
if (!packet) {
break;
}
last_read = millis();
if (packet.get_type() == type) {
handler(packet);
return;
}
handle_packet(packet);
}
if (client.connected()) {
on_timeout();
}
}
void Connection::send_ack(Packet::Type ack_type, uint16_t msg_id) {
TRACE_FUNCTION
auto ack = build_packet(ack_type, 0, 2);
ack.write_u16(msg_id);
ack.send();
}
void Connection::handle_packet(IncomingPacket & packet) {
TRACE_FUNCTION
switch (packet.get_type()) {
case Packet::PUBLISH: {
const uint16_t topic_size = packet.read_u16();
// const bool dup = (packet.get_flags() >> 3) & 0b1;
const uint8_t qos = (packet.get_flags() >> 1) & 0b11;
// const bool retain = packet.get_flags() & 0b1;
uint16_t msg_id = 0;
if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) {
packet.ignore(topic_size);
on_topic_too_long(packet);
if (qos) {
msg_id = packet.read_u16();
}
} else {
char topic[topic_size + 1];
if (!packet.read_string(topic, topic_size)) {
// connection error
return;
}
if (qos) {
msg_id = packet.read_u16();
}
on_message(topic, packet);
}
if (msg_id) {
send_ack(qos == 1 ? Packet::PUBACK : Packet::PUBREC, msg_id);
}
break;
};
case Packet::PUBREC:
send_ack(Packet::PUBREL, packet.read_u16());
break;
case Packet::PUBREL:
send_ack(Packet::PUBCOMP, packet.read_u16());
break;
case Packet::PUBCOMP:
// ignore
break;
case Packet::DISCONNECT:
on_disconnect();
break;
default:
on_protocol_violation();
break;
}
}
unsigned long Connection::get_millis_since_last_read() const {
TRACE_FUNCTION
return millis() - last_read;
}
unsigned long Connection::get_millis_since_last_write() const {
TRACE_FUNCTION
return millis() - last_write;
}
void Connection::loop() {
TRACE_FUNCTION
// only handle 10 packets max in one go to not starve other connections
for (unsigned int i = 0; (i < 10) && client.available(); ++i) {
IncomingPacket packet(client);
if (!packet.is_valid()) {
return;
}
last_read = millis();
handle_packet(packet);
}
}
}

View File

@@ -0,0 +1,79 @@
#pragma once
#include <functional>
#include <memory>
#include <Arduino.h>
#include "client_wrapper.h"
#include "incoming_packet.h"
#include "outgoing_packet.h"
namespace PicoMQTT {
enum ConnectReturnCode : uint8_t {
CRC_ACCEPTED = 0,
CRC_UNACCEPTABLE_PROTOCOL_VERSION = 1,
CRC_IDENTIFIER_REJECTED = 2,
CRC_SERVER_UNAVAILABLE = 3,
CRC_BAD_USERNAME_OR_PASSWORD = 4,
CRC_NOT_AUTHORIZED = 5,
// internal
CRC_UNDEFINED = 255,
};
class Connection {
public:
Connection(::Client & client, unsigned long keep_alive_millis = 0,
unsigned long socket_timeout_millis = 15 * 1000);
virtual ~Connection() {}
bool connected();
void disconnect();
virtual void loop();
protected:
class MessageIdGenerator {
public:
MessageIdGenerator(): value(0) {}
uint16_t generate() {
if (++value == 0) { value = 1; }
return value;
}
void reset() { value = 0; }
protected:
uint16_t value;
} message_id_generator;
OutgoingPacket build_packet(Packet::Type type, uint8_t flags = 0, size_t length = 0);
void wait_for_reply(Packet::Type type, std::function<void(IncomingPacket & packet)> handler);
virtual void on_topic_too_long(const IncomingPacket & packet) {}
virtual void on_message(const char * topic, IncomingPacket & packet) {}
virtual void on_timeout();
virtual void on_protocol_violation();
virtual void on_disconnect();
ClientWrapper client;
unsigned long keep_alive_millis;
virtual void handle_packet(IncomingPacket & packet);
protected:
unsigned long get_millis_since_last_read() const;
unsigned long get_millis_since_last_write() const;
private:
unsigned long last_read;
unsigned long last_write;
void send_ack(Packet::Type ack_type, uint16_t msg_id);
};
}

View File

@@ -0,0 +1,50 @@
#pragma once
#include "config.h"
#ifdef PICOMQTT_DEBUG_TRACE_FUNCTIONS
#include <Arduino.h>
namespace PicoMQTT {
class FunctionTracer {
public:
FunctionTracer(const char * function_name) : function_name(function_name) {
indent(1);
Serial.print(F("CALL "));
Serial.println(function_name);
}
~FunctionTracer() {
indent(-1);
Serial.print(F("RETURN "));
Serial.println(function_name);
}
const char * const function_name;
protected:
void indent(int delta) {
static int depth = 0;
if (delta < 0) {
depth += delta;
}
for (int i = 0; i < depth; ++i) {
Serial.print(" ");
}
if (delta > 0) {
depth += delta;
}
}
};
}
#define TRACE_FUNCTION PicoMQTT::FunctionTracer _function_tracer(__PRETTY_FUNCTION__);
#else
#define TRACE_FUNCTION
#endif

View File

@@ -0,0 +1,189 @@
#include "incoming_packet.h"
#include "debug.h"
namespace PicoMQTT {
IncomingPacket::IncomingPacket(Client & client)
: Packet(read_header(client)), client(client) {
TRACE_FUNCTION
}
IncomingPacket::IncomingPacket(IncomingPacket && other)
: Packet(other), client(other.client) {
TRACE_FUNCTION
other.pos = size;
}
IncomingPacket::IncomingPacket(const Type type, const uint8_t flags, const size_t size, Client & client)
: Packet(type, flags, size), client(client) {
TRACE_FUNCTION
}
IncomingPacket::~IncomingPacket() {
TRACE_FUNCTION
#ifdef PICOMQTT_DEBUG
if (pos != size) {
Serial.print(F("IncomingPacket read incorrect number of bytes: "));
Serial.print(pos);
Serial.print(F("/"));
Serial.println(size);
}
#endif
// read and ignore remaining data
while (get_remaining_size() && (read() >= 0));
}
// disabled functions
int IncomingPacket::connect(IPAddress ip, uint16_t port) {
TRACE_FUNCTION;
return 0;
}
int IncomingPacket::connect(const char * host, uint16_t port) {
TRACE_FUNCTION;
return 0;
}
#ifdef PICOMQTT_EXTRA_CONNECT_METHODS
int IncomingPacket::connect(IPAddress ip, uint16_t port, int32_t timeout) {
TRACE_FUNCTION;
return 0;
}
int IncomingPacket::connect(const char * host, uint16_t port, int32_t timeout) {
TRACE_FUNCTION;
return 0;
}
#endif
size_t IncomingPacket::write(const uint8_t * buffer, size_t size) {
TRACE_FUNCTION
return 0;
}
size_t IncomingPacket::write(uint8_t value) {
TRACE_FUNCTION
return 0;
}
void IncomingPacket::flush() {
TRACE_FUNCTION
}
void IncomingPacket::stop() {
TRACE_FUNCTION
}
// extended functions
int IncomingPacket::available() {
TRACE_FUNCTION;
return get_remaining_size();
}
int IncomingPacket::peek() {
TRACE_FUNCTION
if (!get_remaining_size()) {
#if PICOMQTT_DEBUG
Serial.println(F("Attempt to peek beyond end of IncomingPacket."));
#endif
return -1;
}
return client.peek();
}
int IncomingPacket::read() {
TRACE_FUNCTION
if (!get_remaining_size()) {
#if PICOMQTT_DEBUG
Serial.println(F("Attempt to read beyond end of IncomingPacket."));
#endif
return -1;
}
const int ret = client.read();
if (ret >= 0) {
++pos;
}
return ret;
}
int IncomingPacket::read(uint8_t * buf, size_t size) {
TRACE_FUNCTION
const size_t remaining = get_remaining_size();
const size_t read_size = remaining < size ? remaining : size;
#if PICOMQTT_DEBUG
if (size > remaining) {
Serial.println(F("Attempt to read buf beyond end of IncomingPacket."));
}
#endif
const int ret = client.read(buf, read_size);
if (ret > 0) {
pos += ret;
}
return ret;
}
IncomingPacket::operator bool() {
TRACE_FUNCTION
return is_valid() && bool(client);
}
uint8_t IncomingPacket::connected() {
TRACE_FUNCTION
return is_valid() && client.connected();
}
// extra functions
uint8_t IncomingPacket::read_u8() {
TRACE_FUNCTION;
return get_remaining_size() ? read() : 0;
}
uint16_t IncomingPacket::read_u16() {
TRACE_FUNCTION;
return ((uint16_t) read_u8()) << 8 | ((uint16_t) read_u8());
}
bool IncomingPacket::read_string(char * buffer, size_t len) {
if (read((uint8_t *) buffer, len) != (int) len) {
return false;
}
buffer[len] = '\0';
return true;
}
void IncomingPacket::ignore(size_t len) {
while (len--) {
read();
}
}
Packet IncomingPacket::read_header(Client & client) {
TRACE_FUNCTION
const int head = client.read();
if (head <= 0) {
return Packet();
}
uint32_t size = 0;
for (size_t length_size = 0; ; ++length_size) {
if (length_size >= 5) {
return Packet();
}
const int digit = client.read();
if (digit < 0) {
return Packet();
}
size |= (digit & 0x7f) << (7 * length_size);
if (!(digit & 0x80)) {
break;
}
}
return Packet(head, size);
}
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include <Arduino.h>
#include <Client.h>
#include "config.h"
#include "packet.h"
namespace PicoMQTT {
class IncomingPacket: public Packet, public Client {
public:
IncomingPacket(Client & client);
IncomingPacket(const Type type, const uint8_t flags, const size_t size, Client & client);
IncomingPacket(IncomingPacket &&);
IncomingPacket(const IncomingPacket &) = delete;
const IncomingPacket & operator=(const IncomingPacket &) = delete;
~IncomingPacket();
virtual int available() override;
virtual int connect(IPAddress ip, uint16_t port) override;
virtual int connect(const char * host, uint16_t port) override;
#ifdef PICOMQTT_EXTRA_CONNECT_METHODS
virtual int connect(IPAddress ip, uint16_t port, int32_t timeout) override;
virtual int connect(const char * host, uint16_t port, int32_t timeout) override;
#endif
virtual int peek() override;
virtual int read() override;
virtual int read(uint8_t * buf, size_t size) override;
// This operator is not marked explicit in the Client base class. Still, we're marking it explicit here
// to block implicit conversions to integer types.
virtual explicit operator bool() override;
virtual size_t write(const uint8_t * buffer, size_t size) override;
virtual size_t write(uint8_t value) override final;
virtual uint8_t connected() override;
virtual void flush() override;
virtual void stop() override;
uint8_t read_u8();
uint16_t read_u16();
bool read_string(char * buffer, size_t len);
void ignore(size_t len);
protected:
static Packet read_header(Client & client);
Client & client;
};
}

View File

@@ -0,0 +1,225 @@
#include <Client.h>
#include <Print.h>
#include "debug.h"
#include "outgoing_packet.h"
namespace PicoMQTT {
OutgoingPacket::OutgoingPacket(Print & print, Packet::Type type, uint8_t flags, size_t payload_size)
: Packet(type, flags, payload_size), print(print),
#ifndef PICOMQTT_UNBUFFERED
buffer_position(0),
#endif
state(State::ok) {
TRACE_FUNCTION
}
OutgoingPacket::OutgoingPacket(OutgoingPacket && other)
: OutgoingPacket(other) {
TRACE_FUNCTION
other.state = State::dead;
}
OutgoingPacket::~OutgoingPacket() {
TRACE_FUNCTION
#ifdef PICOMQTT_DEBUG
#ifndef PICOMQTT_UNBUFFERED
if (buffer_position) {
Serial.printf("OutgoingPacket has unsent data in the buffer (pos=%u)\n", buffer_position);
}
#endif
switch (state) {
case State::ok:
Serial.println(F("Unsent OutgoingPacket"));
break;
case State::sent:
if (pos != size) {
Serial.print(F("OutgoingPacket sent incorrect number of bytes: "));
Serial.print(pos);
Serial.print(F("/"));
Serial.println(size);
}
break;
default:
break;
}
#endif
}
size_t OutgoingPacket::write_from_client(::Client & client, size_t length) {
TRACE_FUNCTION
size_t written = 0;
#ifndef PICOMQTT_UNBUFFERED
while (written < length) {
const size_t remaining = length - written;
const size_t remaining_buffer_space = PICOMQTT_OUTGOING_BUFFER_SIZE - buffer_position;
const size_t chunk_size = remaining < remaining_buffer_space ? remaining : remaining_buffer_space;
const int read_size = client.read(buffer + buffer_position, chunk_size);
if (read_size <= 0) {
break;
}
buffer_position += (size_t) read_size;
written += (size_t) read_size;
if (buffer_position >= PICOMQTT_OUTGOING_BUFFER_SIZE) {
flush();
}
}
#else
uint8_t buffer[128] __attribute__((aligned(4)));
while (written < length) {
const size_t remain = length - written;
const size_t chunk_size = sizeof(buffer) < remain ? sizeof(buffer) : remain;
const int read_size = client.read(buffer, chunk_size);
if (read_size <= 0) {
break;
}
const size_t write_size = print.write(buffer, read_size);
written += write_size;
if (!write_size) {
break;
}
}
#endif
pos += written;
return written;
}
size_t OutgoingPacket::write_zero(size_t length) {
TRACE_FUNCTION
for (size_t written = 0; written < length; ++written) {
write_u8('0');
}
return length;
}
#ifndef PICOMQTT_UNBUFFERED
size_t OutgoingPacket::write(const void * data, size_t remaining, void * (*memcpy_fn)(void *, const void *, size_t n)) {
TRACE_FUNCTION
const char * src = (const char *) data;
while (remaining) {
const size_t remaining_buffer_space = PICOMQTT_OUTGOING_BUFFER_SIZE - buffer_position;
const size_t chunk_size = remaining < remaining_buffer_space ? remaining : remaining_buffer_space;
memcpy_fn(buffer + buffer_position, src, chunk_size);
buffer_position += chunk_size;
src += chunk_size;
remaining -= chunk_size;
if (buffer_position >= PICOMQTT_OUTGOING_BUFFER_SIZE) {
flush();
}
}
const size_t written = src - (const char *) data;
pos += written;
return written;
}
#endif
size_t OutgoingPacket::write(const uint8_t * data, size_t length) {
TRACE_FUNCTION
#ifndef PICOMQTT_UNBUFFERED
return write(data, length, memcpy);
#else
const size_t written = print.write(data, length);
pos += written;
return written;
#endif
}
size_t OutgoingPacket::write_P(PGM_P data, size_t length) {
TRACE_FUNCTION
#ifndef PICOMQTT_UNBUFFERED
return write(data, length, memcpy_P);
#else
// here we will need a buffer
uint8_t buffer[128] __attribute__((aligned(4)));
size_t written = 0;
while (written < length) {
const size_t remain = length - written;
const size_t chunk_size = sizeof(buffer) < remain ? sizeof(buffer) : remain;
memcpy_P(buffer, data, chunk_size);
const size_t write_size = print.write(buffer, chunk_size);
written += write_size;
data += write_size;
if (!write_size) {
break;
}
}
pos += written;
return written;
#endif
}
size_t OutgoingPacket::write_u8(uint8_t c) {
TRACE_FUNCTION
return write(&c, 1);
}
size_t OutgoingPacket::write_u16(uint16_t value) {
TRACE_FUNCTION
return write_u8(value >> 8) + write_u8(value & 0xff);
}
size_t OutgoingPacket::write_string(const char * string, uint16_t size) {
TRACE_FUNCTION
return write_u16(size) + write((const uint8_t *) string, size);
}
size_t OutgoingPacket::write_packet_length(size_t length) {
TRACE_FUNCTION
size_t ret = 0;
do {
const uint8_t digit = length & 127; // digit := length % 128
length >>= 7; // length := length / 128
ret += write_u8(digit | (length ? 0x80 : 0));
} while (length);
return ret;
}
size_t OutgoingPacket::write_header() {
TRACE_FUNCTION
const size_t ret = write_u8(head) + write_packet_length(size);
// we've just written the header, payload starts now
pos = 0;
return ret;
}
void OutgoingPacket::flush() {
TRACE_FUNCTION
#ifndef PICOMQTT_UNBUFFERED
print.write(buffer, buffer_position);
buffer_position = 0;
#endif
}
bool OutgoingPacket::send() {
TRACE_FUNCTION
const size_t remaining_size = get_remaining_size();
if (remaining_size) {
#ifdef PICOMQTT_DEBUG
Serial.printf("OutgoingPacket sent called on incomplete payload (%u / %u), filling with zeros.\n", pos, size);
#endif
write_zero(remaining_size);
}
flush();
switch (state) {
case State::ok:
// print.flush();
state = State::sent;
__attribute__((fallthrough));
case State::sent:
return true;
default:
return false;
}
}
}

View File

@@ -0,0 +1,64 @@
#pragma once
// #define MQTT_OUTGOING_PACKET_DEBUG
#include <Arduino.h>
#include "config.h"
#include "packet.h"
class Print;
class Client;
#if PICOMQTT_OUTGOING_BUFFER_SIZE == 0
#define PICOMQTT_UNBUFFERED
#endif
namespace PicoMQTT {
class OutgoingPacket: public Packet, public Print {
public:
OutgoingPacket(Print & print, Type type, uint8_t flags, size_t payload_size);
virtual ~OutgoingPacket();
const OutgoingPacket & operator=(const OutgoingPacket &) = delete;
OutgoingPacket(OutgoingPacket && other);
virtual size_t write(const uint8_t * data, size_t length) override;
virtual size_t write(uint8_t value) override final { return write(&value, 1); }
size_t write_P(PGM_P data, size_t length);
size_t write_u8(uint8_t value);
size_t write_u16(uint16_t value);
size_t write_string(const char * string, uint16_t size);
size_t write_header();
size_t write_from_client(::Client & c, size_t length);
size_t write_zero(size_t count);
virtual void flush() override;
virtual bool send();
protected:
OutgoingPacket(const OutgoingPacket &) = default;
size_t write(const void * data, size_t length, void * (*memcpy_fn)(void *, const void *, size_t n));
size_t write_packet_length(size_t length);
Print & print;
#ifndef PICOMQTT_UNBUFFERED
uint8_t buffer[PICOMQTT_OUTGOING_BUFFER_SIZE] __attribute__((aligned(4)));
size_t buffer_position;
#endif
enum class State {
ok,
sent,
error,
dead,
} state;
};
}

View File

@@ -0,0 +1,49 @@
#pragma once
#include <Arduino.h>
namespace PicoMQTT {
class Packet {
public:
enum Type : uint8_t {
ERROR = 0,
CONNECT = 1 << 4, // Client request to connect to Server
CONNACK = 2 << 4, // Connect Acknowledgment
PUBLISH = 3 << 4, // Publish message
PUBACK = 4 << 4, // Publish Acknowledgment
PUBREC = 5 << 4, // Publish Received (assured delivery part 1)
PUBREL = 6 << 4, // Publish Release (assured delivery part 2)
PUBCOMP = 7 << 4, // Publish Complete (assured delivery part 3)
SUBSCRIBE = 8 << 4, // Client Subscribe request
SUBACK = 9 << 4, // Subscribe Acknowledgment
UNSUBSCRIBE = 10 << 4, // Client Unsubscribe request
UNSUBACK = 11 << 4, // Unsubscribe Acknowledgment
PINGREQ = 12 << 4, // PING Request
PINGRESP = 13 << 4, // PING Response
DISCONNECT = 14 << 4, // Client is Disconnecting
};
Packet(uint8_t head, size_t size)
: head(head), size(size), pos(0) {}
Packet(Type type = ERROR, const uint8_t flags = 0, size_t size = 0)
: Packet((uint8_t) type | (flags & 0xf), size) {
}
virtual ~Packet() {}
Type get_type() const { return Type(head & 0xf0); }
uint8_t get_flags() const { return head & 0x0f; }
bool is_valid() { return get_type() != ERROR; }
size_t get_remaining_size() const { return pos < size ? size - pos : 0; }
const uint8_t head;
const size_t size;
protected:
size_t pos;
};
}

View File

@@ -0,0 +1,13 @@
#pragma once
namespace PicoMQTT {
class PicoMQTTInterface {
public:
virtual ~PicoMQTTInterface() {}
virtual void begin() {}
virtual void stop() {}
virtual void loop() {}
};
}

View File

@@ -0,0 +1,29 @@
#include "print_mux.h"
#include "debug.h"
namespace PicoMQTT {
size_t PrintMux::write(uint8_t value) {
TRACE_FUNCTION
for (auto print_ptr : prints) {
print_ptr->write(value);
}
return 1;
}
size_t PrintMux::write(const uint8_t * buffer, size_t size) {
TRACE_FUNCTION
for (auto print_ptr : prints) {
print_ptr->write(buffer, size);
}
return size;
}
void PrintMux::flush() {
TRACE_FUNCTION
for (auto print_ptr : prints) {
print_ptr->flush();
}
}
}

View File

@@ -0,0 +1,29 @@
#pragma once
#include <vector>
#include <Arduino.h>
namespace PicoMQTT {
class PrintMux: public ::Print {
public:
PrintMux() {}
PrintMux(Print & print) : prints({&print}) {}
void add(Print & print) {
prints.push_back(&print);
}
virtual size_t write(uint8_t) override;
virtual size_t write(const uint8_t * buffer, size_t size) override;
virtual void flush();
size_t size() const { return prints.size(); }
protected:
std::vector<Print *> prints;
};
}

View File

@@ -0,0 +1,56 @@
#include "publisher.h"
#include "debug.h"
namespace PicoMQTT {
Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print,
uint8_t flags, size_t total_size,
const char * topic, size_t topic_size,
uint16_t message_id)
:
OutgoingPacket(this->print, Packet::PUBLISH, flags, total_size),
qos((flags >> 1) & 0b11),
message_id(message_id),
print(print),
publisher(publisher) {
TRACE_FUNCTION
OutgoingPacket::write_header();
write_string(topic, topic_size);
if (qos) {
write_u16(message_id);
}
}
Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print,
const char * topic, size_t topic_size, size_t payload_size,
uint8_t qos, bool retain, bool dup, uint16_t message_id)
: Publish(
publisher, print,
(dup ? 0b1000 : 0) | ((qos & 0b11) << 1) | (retain ? 1 : 0), // flags
2 + topic_size + (qos ? 2 : 0) + payload_size, // total size
topic, topic_size, // topic
message_id) {
TRACE_FUNCTION
}
Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print,
const char * topic, size_t payload_size,
uint8_t qos, bool retain, bool dup, uint16_t message_id)
: Publish(
publisher, print,
topic, strlen(topic), payload_size,
qos, retain, dup, message_id) {
TRACE_FUNCTION
}
Publisher::Publish::~Publish() {
TRACE_FUNCTION
}
bool Publisher::Publish::send() {
TRACE_FUNCTION
return OutgoingPacket::send() && publisher.on_publish_complete(*this);
}
}

View File

@@ -0,0 +1,103 @@
#pragma once
#include <cstring>
#include <Arduino.h>
#include "debug.h"
#include "outgoing_packet.h"
#include "print_mux.h"
namespace PicoMQTT {
class Publisher {
public:
class Publish: public OutgoingPacket {
private:
Publish(Publisher & publisher, const PrintMux & print,
uint8_t flags, size_t total_size,
const char * topic, size_t topic_size,
uint16_t message_id);
public:
Publish(Publisher & publisher, const PrintMux & print,
const char * topic, size_t topic_size, size_t payload_size,
uint8_t qos = 0, bool retain = false, bool dup = false, uint16_t message_id = 0);
Publish(Publisher & publisher, const PrintMux & print,
const char * topic, size_t payload_size,
uint8_t qos = 0, bool retain = false, bool dup = false, uint16_t message_id = 0);
~Publish();
virtual bool send() override;
const uint8_t qos;
const uint16_t message_id;
PrintMux print;
Publisher & publisher;
};
virtual Publish begin_publish(const char * topic, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) = 0;
Publish begin_publish(const String & topic, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
return begin_publish(topic.c_str(), payload_size, qos, retain, message_id);
}
virtual bool publish(const char * topic, const void * payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
auto packet = begin_publish(get_c_str(topic), payload_size, qos, retain, message_id);
packet.write((const uint8_t *) payload, payload_size);
return packet.send();
}
bool publish(const String & topic, const void * payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
return publish(topic.c_str(), payload, payload_size, qos, retain, message_id);
}
template <typename TopicStringType, typename PayloadStringType>
bool publish(TopicStringType topic, PayloadStringType payload,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
return publish(get_c_str(topic), (const void *) get_c_str(payload), get_c_str_len(payload),
qos, retain, message_id);
}
virtual bool publish_P(const char * topic, PGM_P payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
auto packet = begin_publish(topic, payload_size, qos, retain, message_id);
packet.write_P(payload, payload_size);
return packet.send();
}
bool publish_P(const String & topic, PGM_P payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
return publish_P(topic.c_str(), payload, payload_size, qos, retain, message_id);
}
template <typename TopicStringType>
bool publish_P(TopicStringType topic, PGM_P payload,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) {
TRACE_FUNCTION
return publish_P(get_c_str(topic), payload, strlen_P(payload),
qos, retain, message_id);
}
protected:
virtual bool on_publish_complete(const Publish & publish) { return true; }
static const char * get_c_str(const char * string) { return string; }
static const char * get_c_str(const String & string) { return string.c_str(); }
static size_t get_c_str_len(const char * string) { return strlen(string); }
static size_t get_c_str_len(const String & string) { return string.length(); }
};
}

View File

@@ -0,0 +1,431 @@
#include "config.h"
#include "debug.h"
#include "server.h"
namespace {
class BufferClient: public ::Client {
public:
BufferClient(const void * ptr): ptr((const char *) ptr) { TRACE_FUNCTION }
// these methods are nop dummies
virtual int connect(IPAddress ip, uint16_t port) override final { TRACE_FUNCTION return 0; }
virtual int connect(const char * host, uint16_t port) override final { TRACE_FUNCTION return 0; }
#ifdef PICOMQTT_EXTRA_CONNECT_METHODS
virtual int connect(IPAddress ip, uint16_t port, int32_t timeout) override final { TRACE_FUNCTION return 0; }
virtual int connect(const char * host, uint16_t port, int32_t timeout) override final { TRACE_FUNCTION return 0; }
#endif
virtual size_t write(const uint8_t * buffer, size_t size) override final { TRACE_FUNCTION return 0; }
virtual size_t write(uint8_t value) override final { TRACE_FUNCTION return 0; }
virtual void flush() override final { TRACE_FUNCTION }
virtual void stop() override final { TRACE_FUNCTION }
// these methods are in jasager mode
virtual int available() override final { TRACE_FUNCTION return std::numeric_limits<int>::max(); }
virtual operator bool() override final { TRACE_FUNCTION return true; }
virtual uint8_t connected() override final { TRACE_FUNCTION return true; }
// actual reads implemented here
virtual int read(uint8_t * buf, size_t size) override {
memcpy(buf, ptr, size);
ptr += size;
return size;
}
virtual int read() override final {
TRACE_FUNCTION
uint8_t ret;
read(&ret, 1);
return ret;
}
virtual int peek() override final {
TRACE_FUNCTION
const int ret = read();
--ptr;
return ret;
}
protected:
const char * ptr;
};
class BufferClientP: public BufferClient {
public:
using BufferClient::BufferClient;
virtual int read(uint8_t * buf, size_t size) override {
memcpy_P(buf, ptr, size);
ptr += size;
return size;
}
};
}
namespace PicoMQTT {
Server::Client::Client(Server & server, ::Client * client)
:
SocketOwner(client),
Connection(*socket, 0, server.socket_timeout_millis), server(server), client_id("<unknown>") {
TRACE_FUNCTION
wait_for_reply(Packet::CONNECT, [this](IncomingPacket & packet) {
TRACE_FUNCTION
auto connack = [this](ConnectReturnCode crc) {
TRACE_FUNCTION
auto connack = build_packet(Packet::CONNACK, 0, 2);
connack.write_u8(0); /* session present always set to zero */
connack.write_u8(crc);
connack.send();
if (crc != CRC_ACCEPTED) {
Connection::client.stop();
}
};
{
// MQTT protocol identifier
char buf[4];
if (packet.read_u16() != 4) {
on_protocol_violation();
return;
}
packet.read((uint8_t *) buf, 4);
if (memcmp(buf, "MQTT", 4) != 0) {
on_protocol_violation();
return;
}
}
const uint8_t protocol_level = packet.read_u8();
if (protocol_level != 4) {
on_protocol_violation();
return;
}
const uint8_t connect_flags = packet.read_u8();
const bool has_user = connect_flags & (1 << 7);
const bool has_pass = connect_flags & (1 << 6);
const bool will_retain = connect_flags & (1 << 5);
const uint8_t will_qos = (connect_flags >> 3) & 0b11;
const bool has_will = connect_flags & (1 << 2);
/* const bool clean_session = connect_flags & (1 << 1); */
if ((has_pass && !has_user)
|| (will_qos > 2)
|| (!has_will && ((will_qos > 0) || will_retain))) {
on_protocol_violation();
return;
}
const unsigned long keep_alive_seconds = packet.read_u16();
keep_alive_millis = keep_alive_seconds ? (keep_alive_seconds * 1000 + this->server.keep_alive_tolerance_millis) : 0;
{
const size_t client_id_size = packet.read_u16();
if (client_id_size > PICOMQTT_MAX_CLIENT_ID_SIZE) {
connack(CRC_IDENTIFIER_REJECTED);
return;
}
char client_id_buffer[client_id_size + 1];
packet.read_string(client_id_buffer, client_id_size);
client_id = client_id_buffer;
}
if (client_id.isEmpty()) {
client_id = String((unsigned int)(this), HEX);
}
if (has_will) {
packet.ignore(packet.read_u16()); // will topic
packet.ignore(packet.read_u16()); // will payload
}
// read username
const size_t user_size = has_user ? packet.read_u16() : 0;
if (user_size > PICOMQTT_MAX_USERPASS_SIZE) {
connack(CRC_BAD_USERNAME_OR_PASSWORD);
return;
}
char user[user_size + 1];
if (user_size && !packet.read_string(user, user_size)) {
on_timeout();
return;
}
// read password
const size_t pass_size = has_pass ? packet.read_u16() : 0;
if (pass_size > PICOMQTT_MAX_USERPASS_SIZE) {
connack(CRC_BAD_USERNAME_OR_PASSWORD);
return;
}
char pass[pass_size + 1];
if (pass_size && !packet.read_string(pass, pass_size)) {
on_timeout();
return;
}
const auto connect_return_code = this->server.auth(
client_id.c_str(),
has_user ? user : nullptr, has_pass ? pass : nullptr);
connack(connect_return_code);
});
}
void Server::Client::on_message(const char * topic, IncomingPacket & packet) {
TRACE_FUNCTION
const size_t payload_size = packet.get_remaining_size();
auto publish = server.begin_publish(topic, payload_size);
// Always notify the server about the message
{
IncomingPublish incoming_publish(packet, publish);
server.on_message(topic, incoming_publish);
}
publish.send();
}
void Server::Client::on_subscribe(IncomingPacket & subscribe) {
TRACE_FUNCTION
const uint16_t message_id = subscribe.read_u16();
if ((subscribe.get_flags() != 0b0010) || !message_id) {
on_protocol_violation();
return;
}
std::list<uint8_t> suback_codes;
while (subscribe.get_remaining_size()) {
const size_t topic_size = subscribe.read_u16();
if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) {
subscribe.ignore(topic_size);
subscribe.read_u8();
suback_codes.push_back(0x80);
} else {
char topic[topic_size + 1];
if (!subscribe.read_string(topic, topic_size)) {
// connection error
return;
}
uint8_t qos = subscribe.read_u8();
if (qos > 2) {
on_protocol_violation();
return;
}
this->subscribe(topic);
server.on_subscribe(client_id.c_str(), topic);
suback_codes.push_back(0);
}
}
auto suback = build_packet(Packet::SUBACK, 0, 2 + suback_codes.size());
suback.write_u16(message_id);
for (uint8_t code : suback_codes) {
suback.write_u8(code);
}
suback.send();
}
void Server::Client::on_unsubscribe(IncomingPacket & unsubscribe) {
TRACE_FUNCTION
const uint16_t message_id = unsubscribe.read_u16();
if ((unsubscribe.get_flags() != 0b0010) || !message_id) {
on_protocol_violation();
return;
}
while (unsubscribe.get_remaining_size()) {
const size_t topic_size = unsubscribe.read_u16();
if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) {
unsubscribe.ignore(topic_size);
} else {
char topic[topic_size + 1];
if (!unsubscribe.read_string(topic, topic_size)) {
// connection error
return;
}
server.on_unsubscribe(client_id.c_str(), topic);
this->unsubscribe(topic);
}
}
auto unsuback = build_packet(Packet::UNSUBACK, 0, 2);
unsuback.write_u16(message_id);
unsuback.send();
}
const char * Server::Client::get_subscription_pattern(Server::Client::SubscriptionId id) const {
for (const auto & pattern : subscriptions)
if (pattern.id == id) {
return pattern.c_str();
}
return nullptr;
}
Server::Client::SubscriptionId Server::Client::get_subscription(const char * topic) const {
TRACE_FUNCTION
for (const auto & pattern : subscriptions)
if (topic_matches(pattern.c_str(), topic)) {
return pattern.id;
}
return 0;
}
Server::Client::SubscriptionId Server::Client::subscribe(const String & topic_filter) {
TRACE_FUNCTION
const Subscription subscription(topic_filter.c_str());
subscriptions.insert(subscription);
return subscription.id;
}
void Server::Client::unsubscribe(const String & topic_filter) {
TRACE_FUNCTION
subscriptions.erase(topic_filter.c_str());
}
void Server::Client::handle_packet(IncomingPacket & packet) {
TRACE_FUNCTION
switch (packet.get_type()) {
case Packet::PINGREQ:
build_packet(Packet::PINGRESP).send();
return;
case Packet::SUBSCRIBE:
on_subscribe(packet);
return;
case Packet::UNSUBSCRIBE:
on_unsubscribe(packet);
return;
default:
Connection::handle_packet(packet);
return;
}
}
void Server::Client::loop() {
TRACE_FUNCTION
if (keep_alive_millis && (get_millis_since_last_read() > keep_alive_millis)) {
// ping timeout
on_timeout();
return;
}
Connection::loop();
}
Server::IncomingPublish::IncomingPublish(IncomingPacket & packet, Publish & publish)
: IncomingPacket(std::move(packet)), publish(publish) {
TRACE_FUNCTION
}
Server::IncomingPublish::~IncomingPublish() {
TRACE_FUNCTION
pos += publish.write_from_client(client, get_remaining_size());
}
int Server::IncomingPublish::read(uint8_t * buf, size_t size) {
TRACE_FUNCTION
const int ret = IncomingPacket::read(buf, size);
if (ret > 0) {
publish.write(buf, ret);
}
return ret;
}
int Server::IncomingPublish::read() {
TRACE_FUNCTION
const int ret = IncomingPacket::read();
if (ret >= 0) {
publish.write(ret);
}
return ret;
}
Server::Server(std::unique_ptr<ServerSocketInterface> server)
: keep_alive_tolerance_millis(10 * 1000), socket_timeout_millis(5 * 1000), server(std::move(server)) {
TRACE_FUNCTION
}
void Server::begin() {
TRACE_FUNCTION
server->begin();
}
void Server::loop() {
TRACE_FUNCTION
::Client * client_ptr = server->accept_client();
if (client_ptr) {
clients.push_back(std::unique_ptr<Client>(new Client(*this, client_ptr)));
on_connected(clients.back()->get_client_id());
}
for (auto it = clients.begin(); it != clients.end();) {
Client & client = **it;
client.loop();
if (!client.connected()) {
on_disconnected(client.get_client_id());
clients.erase(it++);
} else {
++it;
}
}
}
PrintMux Server::get_subscribed(const char * topic) {
TRACE_FUNCTION
PrintMux ret;
for (auto & client_ptr : clients) {
if (client_ptr->get_subscription(topic)) {
ret.add(client_ptr->get_print());
}
}
return ret;
}
Publisher::Publish Server::begin_publish(const char * topic, const size_t payload_size,
uint8_t, bool, uint16_t) {
TRACE_FUNCTION
return Publish(*this, get_subscribed(topic), topic, payload_size);
}
void Server::on_message(const char * topic, IncomingPacket & packet) {
TRACE_FUNCTION
fire_message_callbacks(topic, packet);
}
bool ServerLocalSubscribe::publish(const char * topic, const void * payload, const size_t payload_size,
uint8_t qos, bool retain, uint16_t message_id) {
TRACE_FUNCTION
const bool ret = Server::publish(topic, payload, payload_size, qos, retain, message_id);
BufferClient buffer(payload);
IncomingPacket packet(IncomingPacket::PUBLISH, 0, payload_size, buffer);
fire_message_callbacks(topic, packet);
return ret;
}
bool ServerLocalSubscribe::publish_P(const char * topic, PGM_P payload, const size_t payload_size,
uint8_t qos, bool retain, uint16_t message_id) {
TRACE_FUNCTION
const bool ret = Server::publish_P(topic, payload, payload_size, qos, retain, message_id);
BufferClientP buffer((void *) payload);
IncomingPacket packet(IncomingPacket::PUBLISH, 0, payload_size, buffer);
fire_message_callbacks(topic, packet);
return ret;
}
}

View File

@@ -0,0 +1,230 @@
#pragma once
#include <list>
#include <set>
#include <Arduino.h>
#if defined(ESP32)
#include <WiFi.h>
#elif defined(ESP8266)
#include <ESP8266WiFi.h>
#else
#error "This board is not supported."
#endif
#include "debug.h"
#include "incoming_packet.h"
#include "connection.h"
#include "publisher.h"
#include "subscriber.h"
#include "pico_interface.h"
#include "utils.h"
namespace PicoMQTT {
class ServerSocketInterface {
public:
ServerSocketInterface() {}
virtual ~ServerSocketInterface() {}
ServerSocketInterface(const ServerSocketInterface &) = delete;
const ServerSocketInterface & operator=(const ServerSocketInterface &) = delete;
virtual void begin() = 0;
virtual ::Client * accept_client() = 0;
};
template <typename Server>
class ServerSocket: public ServerSocketInterface, public Server {
public:
using Server::Server;
virtual ::Client * accept_client() override {
TRACE_FUNCTION
auto client = Server::accept();
if (!client) {
// no connection
return nullptr;
}
return new decltype(client)(client);
};
virtual void begin() override {
TRACE_FUNCTION
Server::begin();
}
};
template <typename Server>
class ServerSocketProxy: public ServerSocketInterface {
public:
Server & server;
ServerSocketProxy(Server & server): server(server) {}
virtual ::Client * accept_client() override {
TRACE_FUNCTION
auto client = server.accept();
if (!client) {
// no connection
return nullptr;
}
return new decltype(client)(client);
};
virtual void begin() override {
TRACE_FUNCTION
server.begin();
}
};
class ServerSocketMux: public ServerSocketInterface {
public:
template <typename... Targs>
ServerSocketMux(Targs & ... Fargs) {
add(Fargs...);
}
virtual ::Client * accept_client() override {
TRACE_FUNCTION
for (auto & server : servers) {
auto client = server->accept_client();
if (client) {
// no connection
return client;
}
}
return nullptr;
};
virtual void begin() override {
TRACE_FUNCTION
for (auto & server : servers) {
server->begin();
}
}
protected:
template <typename Server>
void add(Server & server) {
servers.push_back(std::unique_ptr<ServerSocketInterface>(new ServerSocketProxy<Server>(server)));
}
template <typename Server, typename... Targs>
void add(Server & server, Targs & ... Fargs) {
add(server);
add(Fargs...);
}
std::vector<std::unique_ptr<ServerSocketInterface>> servers;
};
class Server: public PicoMQTTInterface, public Publisher, public SubscribedMessageListener {
public:
class Client: public SocketOwner<std::unique_ptr<::Client>>, public Connection, public Subscriber {
public:
Client(Server & server, ::Client * client);
void on_message(const char * topic, IncomingPacket & packet) override;
Print & get_print() { return Connection::client; }
const char * get_client_id() const { return client_id.c_str(); }
virtual void loop() override;
virtual const char * get_subscription_pattern(SubscriptionId id) const override;
virtual SubscriptionId get_subscription(const char * topic) const override;
virtual SubscriptionId subscribe(const String & topic_filter) override;
virtual void unsubscribe(const String & topic_filter) override;
protected:
Server & server;
String client_id;
std::set<Subscription> subscriptions;
virtual void on_subscribe(IncomingPacket & packet);
virtual void on_unsubscribe(IncomingPacket & packet);
virtual void handle_packet(IncomingPacket & packet) override;
};
class IncomingPublish: public IncomingPacket {
public:
IncomingPublish(IncomingPacket & packet, Publish & publish);
IncomingPublish(const IncomingPublish &) = delete;
~IncomingPublish();
virtual int read(uint8_t * buf, size_t size) override;
virtual int read() override;
protected:
Publish & publish;
};
Server(std::unique_ptr<ServerSocketInterface> socket);
Server(uint16_t port = 1883)
: Server(new ServerSocket<::WiFiServer>(port)) {
TRACE_FUNCTION
}
template <typename ServerType>
Server(ServerType & server)
: Server(new ServerSocketProxy<ServerType>(server)) {
TRACE_FUNCTION
}
template <typename ServerType, typename... Targs>
Server(ServerType & server, Targs & ... Fargs): Server(new ServerSocketMux(server, Fargs...)) {
TRACE_FUNCTION
}
void begin() override;
void loop() override;
using Publisher::begin_publish;
virtual Publish begin_publish(const char * topic, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override;
unsigned long keep_alive_tolerance_millis;
unsigned long socket_timeout_millis;
protected:
Server(ServerSocketInterface * socket)
: Server(std::unique_ptr<ServerSocketInterface>(socket)) {
TRACE_FUNCTION
}
virtual void on_message(const char * topic, IncomingPacket & packet);
virtual ConnectReturnCode auth(const char * client_id, const char * username, const char * password) { return CRC_ACCEPTED; }
virtual void on_connected(const char * client_id) {}
virtual void on_disconnected(const char * client_id) {}
virtual void on_subscribe(const char * client_id, const char * topic) {}
virtual void on_unsubscribe(const char * client_id, const char * topic) {}
virtual PrintMux get_subscribed(const char * topic);
std::unique_ptr<ServerSocketInterface> server;
std::list<std::unique_ptr<Client>> clients;
};
class ServerLocalSubscribe: public Server {
public:
using Server::Server;
using Server::publish;
using Server::publish_P;
virtual bool publish(const char * topic, const void * payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override;
virtual bool publish_P(const char * topic, PGM_P payload, const size_t payload_size,
uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override;
};
}

View File

@@ -0,0 +1,161 @@
#include "subscriber.h"
#include "incoming_packet.h"
#include "debug.h"
namespace PicoMQTT {
String Subscriber::get_topic_element(const char * topic, size_t index) {
while (index && topic[0]) {
if (topic++[0] == '/') {
--index;
}
}
if (!topic[0]) {
return "";
}
const char * end = topic;
while (*end && *end != '/') {
++end;
}
String ret;
ret.concat(topic, end - topic);
return ret;
}
String Subscriber::get_topic_element(const String & topic, size_t index) {
TRACE_FUNCTION
return get_topic_element(topic.c_str(), index);
}
bool Subscriber::topic_matches(const char * p, const char * t) {
TRACE_FUNCTION
// TODO: Special handling of the $ prefix
while (true) {
switch (*p) {
case '\0':
// end of pattern reached
// TODO: check for '/#' suffix
return (*t == '\0');
case '#':
// multilevel wildcard
if (*t == '\0') {
return false;
}
return true;
case '+':
// single level wildcard
while (*t && *t != '/') {
++t;
}
++p;
break;
default:
// regular match
if (*p != *t) {
return false;
}
++p;
++t;
}
}
}
const char * SubscribedMessageListener::get_subscription_pattern(SubscriptionId id) const {
TRACE_FUNCTION
for (const auto & kv : subscriptions) {
if (kv.first.id == id) {
return kv.first.c_str();
}
}
return nullptr;
}
Subscriber::SubscriptionId SubscribedMessageListener::get_subscription(const char * topic) const {
TRACE_FUNCTION
for (const auto & kv : subscriptions) {
if (topic_matches(kv.first.c_str(), topic)) {
return kv.first.id;
}
}
return 0;
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter) {
TRACE_FUNCTION
return subscribe(topic_filter, [this](const char * topic, IncomingPacket & packet) { on_extra_message(topic, packet); });
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, MessageCallback callback) {
TRACE_FUNCTION
unsubscribe(topic_filter);
auto pair = subscriptions.emplace(std::make_pair(Subscription(topic_filter), callback));
return pair.first->first.id;
}
void SubscribedMessageListener::unsubscribe(const String & topic_filter) {
TRACE_FUNCTION
subscriptions.erase(topic_filter);
}
void SubscribedMessageListener::fire_message_callbacks(const char * topic, IncomingPacket & packet) {
TRACE_FUNCTION
for (const auto & kv : subscriptions) {
if (topic_matches(kv.first.c_str(), topic)) {
kv.second((char *) topic, packet);
return;
}
}
on_extra_message(topic, packet);
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter,
std::function<void(char *, void *, size_t)> callback, size_t max_size) {
TRACE_FUNCTION
return subscribe(topic_filter, [this, callback, max_size](char * topic, IncomingPacket & packet) {
const size_t payload_size = packet.get_remaining_size();
if (payload_size >= max_size) {
on_message_too_big(topic, packet);
return;
}
char payload[payload_size + 1];
if (packet.read((uint8_t *) payload, payload_size) != (int) payload_size) {
// connection error, ignore
return;
}
payload[payload_size] = '\0';
callback(topic, payload, payload_size);
});
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter,
std::function<void(char *, char *)> callback, size_t max_size) {
TRACE_FUNCTION
return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) {
callback(topic, (char *) payload);
}, max_size);
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter,
std::function<void(char *)> callback, size_t max_size) {
TRACE_FUNCTION
return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) {
callback((char *) payload);
}, max_size);
}
Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter,
std::function<void(void *, size_t)> callback, size_t max_size) {
TRACE_FUNCTION
return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) {
callback(payload, payload_size);
}, max_size);
}
};

View File

@@ -0,0 +1,73 @@
#pragma once
#include <functional>
#include <map>
#include <Arduino.h>
#include "autoid.h"
#include "config.h"
namespace PicoMQTT {
class IncomingPacket;
class Subscriber {
public:
typedef AutoId::Id SubscriptionId;
static bool topic_matches(const char * topic_filter, const char * topic);
static String get_topic_element(const char * topic, size_t index);
static String get_topic_element(const String & topic, size_t index);
virtual const char * get_subscription_pattern(SubscriptionId id) const = 0;
virtual SubscriptionId get_subscription(const char * topic) const = 0;
virtual SubscriptionId subscribe(const String & topic_filter) = 0;
virtual void unsubscribe(const String & topic_filter) = 0;
void unsubscribe(SubscriptionId id) { unsubscribe(get_subscription_pattern(id)); }
protected:
class Subscription: public String, public AutoId {
public:
using String::String;
Subscription(const String & str): Subscription(str.c_str()) {}
};
};
class SubscribedMessageListener: public Subscriber {
public:
// NOTE: None of the callback functions use const arguments for wider compatibility. It's still OK (and
// recommended) to use callbacks which take const arguments. Similarly with Strings.
typedef std::function<void(char * topic, IncomingPacket & packet)> MessageCallback;
virtual const char * get_subscription_pattern(SubscriptionId id) const override;
virtual SubscriptionId get_subscription(const char * topic) const override;
virtual SubscriptionId subscribe(const String & topic_filter) override;
virtual SubscriptionId subscribe(const String & topic_filter, MessageCallback callback);
SubscriptionId subscribe(const String & topic_filter, std::function<void(char *, void *, size_t)> callback,
size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE);
SubscriptionId subscribe(const String & topic_filter, std::function<void(char *, char *)> callback,
size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE);
SubscriptionId subscribe(const String & topic_filter, std::function<void(void *, size_t)> callback,
size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE);
SubscriptionId subscribe(const String & topic_filter, std::function<void(char *)> callback,
size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE);
virtual void unsubscribe(const String & topic_filter) override;
virtual void on_extra_message(const char * topic, IncomingPacket & packet) {}
virtual void on_message_too_big(const char * topic, IncomingPacket & packet) {}
protected:
void fire_message_callbacks(const char * topic, IncomingPacket & packet);
std::map<Subscription, MessageCallback> subscriptions;
};
}

View File

@@ -0,0 +1,16 @@
#pragma once
#include <utility>
namespace PicoMQTT {
template <typename T>
struct SocketOwner {
SocketOwner() {}
template <typename ...Args>
SocketOwner(Args && ...args): socket(std::forward<Args>(args)...) {}
virtual ~SocketOwner() {}
T socket;
};
}