1 module drmi.mqtt.client;
2 
3 import mqttd;
4 
5 import std.algorithm : map;
6 import std.array : array;
7 
/**
 * MQTT client that dispatches incoming PUBLISH messages to callbacks
 * registered per topic filter, and subscribes to all registered filters
 * every time the broker acknowledges a connection (CONNACK).
 */
class IMClient : MqttClient
{
protected:
    /// One registered subscription: topic filter, its callback and requested QoS.
    static struct CB
    {
        string pattern;
        void delegate(string, const(ubyte)[]) func;
        QoSLevel qos;
    }

    /// All registered subscriptions, in registration order.
    CB[] slist;

public:
    /// Forwards the settings to the base MqttClient.
    this(Settings s) { super(s); }

    /**
     * After the base handler runs, calls every registered callback whose
     * filter matches `msg.topic`, in registration order, passing the topic
     * and payload. An exception thrown by `match` or by a callback propagates
     * and skips the remaining callbacks.
     */
    override void onPublish(Publish msg)
    {
        super.onPublish(msg);
        // NOTE(review): presumably @trusted because the stored callbacks are
        // @system while the overridden base method is @safe — confirm.
        () @trusted
        {
            foreach (cb; slist)
                if (match(msg.topic, cb.pattern))
                    cb.func(msg.topic, msg.payload);
        }();
    }

    /**
     * Registers `cb` to be called for incoming messages whose topic matches
     * the topic filter `pattern`, requesting `qos` from the broker.
     *
     * NOTE: this only records the subscription. The SUBSCRIBE packet is sent
     * from onConnAck, so a filter registered after the connection has been
     * acknowledged takes effect only on the next (re)connect.
     *
     * Throws: Exception if `cb` is null, `pattern` is empty, or `pattern` is
     *         not a valid MQTT topic filter ('#' not in the last level, or
     *         '+'/'#' sharing a level with other characters).
     */
    void subscribe(string pattern, void delegate(string, const(ubyte)[]) cb, QoSLevel qos)
    {
        import std.algorithm : canFind;
        import std.array : split;
        import std.exception : enforce;

        // Reject malformed input here rather than while dispatching incoming
        // messages: onPublish would call a null delegate, and a filter with a
        // non-final '#' makes `match` throw for every received message.
        enforce(cb !is null, "callback must not be null");
        enforce(pattern.length > 0, "topic filter must not be empty");

        auto levels = pattern.split("/");
        foreach (i, lvl; levels)
        {
            if (lvl == "#")
                enforce(i + 1 == levels.length, "# must be final char");
            else if (lvl != "+")
                // MQTT 3.1.1 4.7.1: wildcards must occupy an entire level.
                enforce(!lvl.canFind('#') && !lvl.canFind('+'),
                    "wildcards must occupy an entire topic level");
        }

        slist ~= CB(pattern, cb, qos);
    }

    /**
     * After the base handler runs, sends one SUBSCRIBE per QoS level that
     * contains every registered filter requested with that level. Levels with
     * no registered filters are skipped.
     */
    override void onConnAck(ConnAck ca)
    {
        import std.algorithm : filter, map;
        super.onConnAck(ca);
        () @trusted
        {
            void fltr(QoSLevel lvl)
            {
                auto lst = slist.filter!(a=>a.qos==lvl).map!(a=>a.pattern).array;
                if (lst.length) super.subscribe(lst, lvl);
            }

            fltr(QoSLevel.QoS0);
            fltr(QoSLevel.QoS1);
            fltr(QoSLevel.QoS2);
        } ();
    }
}
57 
/**
 * Checks whether an MQTT topic name matches a topic filter (MQTT 3.1.1, 4.7).
 *
 * Levels are separated by '/'. Matching rules:
 * - A level consisting solely of '+' matches exactly one level, including an
 *   empty one.
 * - A final level consisting solely of '#' matches the parent level and any
 *   number of child levels, so "a/#" matches "a", "a/b" and "a/b/c".
 * - A filter whose first level is a wildcard never matches a topic that
 *   starts with '$' (e.g. "$SYS/...").
 *
 * Params:
 *   topic   = topic name of a received message
 *   pattern = topic filter, possibly containing wildcards
 * Returns: true if `topic` matches `pattern`.
 * Throws: Exception if `pattern` is not a valid filter, i.e. '#' is not the
 *         last level, or '+'/'#' share a level with other characters.
 */
private bool match(string topic, string pattern)
{
    enum ANY = "+";
    enum ANYLVL = "#";

    import std.algorithm : canFind;
    import std.array : split;
    import std.exception : enforce;

    debug (TopicMatchDebug) import std.stdio;

    auto pat = pattern.split("/");

    // Validate the whole filter up front, so an invalid filter is rejected
    // regardless of which topic it is compared with.
    foreach (i, e; pat)
    {
        if (e == ANYLVL)
            enforce(i + 1 == pat.length, "# must be final char");
        else if (e != ANY)
            enforce(!e.canFind('#') && !e.canFind('+'),
                "wildcards must occupy an entire topic level");
    }

    // MQTT 3.1.1 4.7.2: filters starting with a wildcard must not match topics
    // beginning with '$' (reserved for broker-internal topics such as $SYS).
    if (topic.length && topic[0] == '$' && pat.length && (pat[0] == ANY || pat[0] == ANYLVL))
    {
        debug (TopicMatchDebug) stderr.writeln("wildcard vs $-topic: ", pattern, " ", topic, " returns false");
        return false;
    }

    auto top = topic.split("/");

    foreach (i, e; pat)
    {
        if (e == ANYLVL)
        {
            // Every earlier level matched, so i <= top.length here: '#' covers
            // the parent level itself (i == top.length) and any child levels.
            debug (TopicMatchDebug) stderr.writeln("matched: ", pat, " ", top);
            return true;
        }
        if (i >= top.length)
        {
            debug (TopicMatchDebug) stderr.writeln("pat length more that top: ", pat, " ", top, " returns false");
            return false;
        }
        if (e != ANY && e != top[i])
        {
            debug (TopicMatchDebug) stderr.writefln("%s %s mismatch %s and %s (idx: %d) returns false", pat, top, e, top[i], i);
            return false;
        }
    }

    // No '#': the topic must have exactly as many levels as the filter.
    immutable matched = pat.length == top.length;
    debug (TopicMatchDebug)
    {
        if (matched) stderr.writeln("matched: ", pat, " ", top);
        else stderr.writeln("no ANYLVL and mismatch levels: ", pat, " ", top, " returns false");
    }
    return matched;
}

unittest
{
    import std.exception;
    assertThrown(match("any", "#/"));
    assertNotThrown(match("any", "/#"));
    // wildcards must occupy a whole level
    assertThrown(match("a/b", "a/b#"));
    assertThrown(match("a/b", "a+/b"));

    assert( match("a/b/c/d", "a/b/c/d"));
    assert( match("a/b/c/d", "+/b/c/d"));
    assert( match("a/b/c/d", "a/+/c/d"));
    assert( match("a/b/c/d", "a/b/+/d"));
    assert( match("a/b/c/d", "a/b/c/+"));
    assert( match("a/b/c/d", "+/b/c/+"));
    assert( match("a/b/c/d", "a/+/c/+"));
    assert( match("a/b/c/d", "a/b/+/+"));
    assert( match("a/b/c/d", "+/b/+/+"));
    assert( match("a/b/c/d", "a/+/+/+"));
    assert( match("a/b/c/d", "+/+/+/+"));

    assert(!match("a/b/c/d", "b/+/c/d"));
    assert(!match("a/b/c/d", "a/b/c"));
    assert(!match("a/b/c/d", "+/+/+"));

    assert( match("a/b/c/d", "+/+/#"));
    // '#' also matches the parent level (MQTT 3.1.1 4.7.1.2)
    assert( match("a/b", "+/+/#"));
    assert( match("a", "a/#"));
    assert( match("a/b", "a/b/#"));
    assert(!match("a", "a/b/#"));
    assert( match("a/b/", "+/+/#"));

    assert( match("/a/b", "#"));
    assert( match("/a//b", "#"));
    assert( match("a//b", "#"));
    assert( match("a/b", "#"));
    assert( match("a/b/", "#"));
    assert( match("/a/b", "/#"));
    assert( match("/a/b", "/+/b"));
    assert( match("/a/b", "+/+/b"));
    assert( match("/a//b", "/#"));
    assert( match("/a//b", "/a/+/b"));

    // leading wildcards never match $-topics (MQTT 3.1.1 4.7.2)
    assert(!match("$SYS/broker", "#"));
    assert(!match("$SYS/broker", "+/broker"));
    assert( match("$SYS/broker", "$SYS/#"));
    assert( match("$SYS/broker", "$SYS/+"));
}