1 module drmi.mqtt.subscriber;
2 
3 import mqttd;
4 
5 import std.algorithm : map;
6 import std.array : array;
7 
/**
 * MQTT client that dispatches incoming PUBLISH messages to callbacks
 * registered per topic filter.
 *
 * Every filter passed to `subscribe` is remembered in `slist` and the whole
 * list is sent to the broker again each time `onConnAck` runs.
 */
class Subscriber : MqttClient
{
protected:
    /// One registered callback together with the topic filter it listens on.
    static struct CB
    {
        string pattern;
        void delegate(string, const(ubyte)[]) func;
    }

    /// All registered callbacks, in registration order.
    CB[] slist;
    /// QoS level requested for every subscription made by this client.
    QoSLevel qos;

public:
    /**
     * Params:
     *   s   = settings forwarded to the `MqttClient` constructor
     *   qos = QoS level used for every subscription this client sends
     */
    this(Settings s, QoSLevel qos) { super(s); this.qos = qos; }

    /**
     * Forwards `msg` to the base class, then invokes every registered
     * callback whose filter matches `msg.topic`, in registration order.
     *
     * Throws: the exception raised by `match` when a registered filter is
     *         malformed (`#` not in the last level).
     */
    override void onPublish(Publish msg)
    {
        super.onPublish(msg);
        // NOTE(review): the @trusted lambda presumably exists because the
        // base-class method is @safe while the stored delegates are not -- confirm.
        () @trusted
        {
            foreach (cb; slist)
                if (match(msg.topic, cb.pattern))
                    cb.func(msg.topic, msg.payload);
        }();
    }

    /**
     * Registers `cb` for messages whose topic matches `pattern`.
     *
     * The filter is always remembered (and therefore sent on the next
     * `onConnAck`); when `this.connected` is true it is also subscribed
     * immediately.
     */
    void subscribe(string pattern, void delegate(string, const(ubyte)[]) cb)
    {
        slist ~= CB(pattern, cb);
        if (this.connected)
            super.subscribe([pattern], qos);
    }

    /**
     * After the base class handles the CONNACK, subscribes every filter
     * registered so far in a single request.
     */
    override void onConnAck(ConnAck ca)
    {
        super.onConnAck(ca);
        // A SUBSCRIBE packet must carry at least one topic filter
        // (MQTT 3.1.1 [MQTT-3.8.3-3]); with nothing registered yet, an empty
        // request would be a protocol violation, so send nothing.
        if (slist.length == 0)
            return;
        () @trusted
        {
            super.subscribe(slist.map!(a => a.pattern).array, qos);
        } ();
    }
}
50 
/**
 * Checks whether an MQTT topic name matches a topic filter.
 *
 * Levels are separated by `/`. In `pattern`, a `+` level matches exactly one
 * topic level (including an empty one), and a final `#` level matches the
 * parent level itself plus any number of further levels, so `"a/#"` matches
 * `"a"`, `"a/b"` and `"a/b/c"` (MQTT 3.1.1 §4.7.1.2). A filter whose first
 * level is a wildcard never matches a topic beginning with `$`
 * (MQTT 3.1.1 §4.7.2).
 *
 * Params:
 *   topic   = topic name of a received message, e.g. "a/b/c"
 *   pattern = topic filter, possibly containing `+` and `#` levels
 *
 * Returns: `true` when `topic` matches `pattern`.
 *
 * Throws: `Exception` (via `enforce`) when a `#` level is not the last level
 *         of `pattern`.
 */
private bool match(string topic, string pattern)
{
    enum ANY = "+";
    enum ANYLVL = "#";

    import std.algorithm : find;
    import std.exception : enforce;
    import std.string : split;

    debug (TopicMatchDebug) import std.stdio;

    auto pat = pattern.split("/");

    // `find` yields the tail starting at the first "#" level; anything longer
    // than one element means "#" is followed by further levels.
    auto fanylvl = pat.find(ANYLVL);
    enforce(fanylvl.length <= 1, "# must be final char");

    auto top = topic.split("/");

    // Wildcard-first filters must not match "$..." topics such as "$SYS/...".
    if (topic.length && topic[0] == '$' && pat.length && (pat[0] == ANY || pat[0] == ANYLVL))
    {
        debug (TopicMatchDebug) stderr.writeln("wildcard-first filter vs '$' topic: ", pat, " ", top, " returns false");
        return false;
    }

    // Without "#", filter and topic must have the same number of levels.
    if (fanylvl.length == 0 && pat.length != top.length)
    {
        debug (TopicMatchDebug) stderr.writeln("no ANYLVL and mismatch levels: ", pat, " ", top, " returns false");
        return false;
    }

    foreach (i, e; pat)
    {
        // "#" is necessarily the last level (enforced above). It is checked
        // before the length test so that it also matches the parent level.
        if (e == ANYLVL)
        {
            debug (TopicMatchDebug) stderr.writeln("matched: ", pat, " ", top);
            return true;
        }
        if (i >= top.length)
        {
            debug (TopicMatchDebug) stderr.writeln("pat length more than top: ", pat, " ", top, " returns false");
            return false;
        }
        if (e != ANY && e != top[i])
        {
            debug (TopicMatchDebug) stderr.writefln("%s %s mismatch %s and %s (idx: %d) returns false", pat, top, e, top[i], i);
            return false;
        }
    }
    debug (TopicMatchDebug) stderr.writeln("matched: ", pat, " ", top);
    return true;
}

unittest
{
    import std.exception;
    assertThrown(match("any", "#/"));
    assertNotThrown(match("any", "/#"));

    assert( match("a/b/c/d", "a/b/c/d"));
    assert( match("a/b/c/d", "+/b/c/d"));
    assert( match("a/b/c/d", "a/+/c/d"));
    assert( match("a/b/c/d", "a/b/+/d"));
    assert( match("a/b/c/d", "a/b/c/+"));
    assert( match("a/b/c/d", "+/b/c/+"));
    assert( match("a/b/c/d", "a/+/c/+"));
    assert( match("a/b/c/d", "a/b/+/+"));
    assert( match("a/b/c/d", "+/b/+/+"));
    assert( match("a/b/c/d", "a/+/+/+"));
    assert( match("a/b/c/d", "+/+/+/+"));

    assert(!match("a/b/c/d", "b/+/c/d"));
    assert(!match("a/b/c/d", "a/b/c"));
    assert(!match("a/b/c/d", "+/+/+"));

    // "#" also matches the parent level (MQTT 3.1.1 §4.7.1.2)
    assert( match("a/b/c/d", "+/+/#"));
    assert( match("a/b", "+/+/#"));
    assert( match("a", "a/#"));
    assert(!match("a", "+/+/#"));
    assert( match("a/b/", "+/+/#"));

    assert( match("/a/b", "#"));
    assert( match("/a//b", "#"));
    assert( match("a//b", "#"));
    assert( match("a/b", "#"));
    assert( match("a/b/", "#"));
    assert( match("/a/b", "/#"));
    assert( match("/a/b", "/+/b"));
    assert( match("/a/b", "+/+/b"));
    assert( match("/a//b", "/#"));
    assert( match("/a//b", "/a/+/b"));

    // wildcard-first filters do not match "$" topics (MQTT 3.1.1 §4.7.2)
    assert(!match("$SYS/a", "#"));
    assert(!match("$SYS/a", "+/a"));
    assert( match("$SYS/a", "$SYS/#"));
    assert( match("$SYS/a", "$SYS/+"));
}