1 ///
2 module drmi;
3 
4 import vibe.data.json;
5 
6 ///
7 alias RMIData = Json;
8 ///
9 RMIData serialize(T)(auto ref const T val) { return val.serializeToJson; }
10 ///
11 T deserialize(T)(auto ref const RMIData d) { return d.deserializeJson!T; }
12 ///
13 auto as(T)(auto ref const RMIData d) @property { return d.deserialize!T; }
14 
15 ///
16 RMIData rmiEmptyArrayData() { return Json.emptyArray; }
17 
18 ///
19 struct RMICall
20 {
21     ///
22     string caller;
23     ///
24     string func;
25     ///
26     long ts;
27     ///
28     RMIData data;
29 }
30 
31 ///
32 struct RMIResponse
33 {
34     ///
35     uint status;
36     ///
37     RMICall call;
38     ///
39     RMIData data;
40 }
41 
42 ///
43 class RMIException : Exception
44 {
45     ///
46     RMIData data;
47     ///
48     this(RMIData data)
49     {
50         this.data = data;
51         super(data.toString);
52     }
53 }
54 
55 ///
56 class RMITimeoutException : Exception
57 {
58     import std.conv : text;
59     ///
60     RMICall call;
61     ///
62     this(RMICall c) { call = c; super(text(c)); }
63 }
64 
65 ///
66 interface RMICom
67 {
68     ///
69     RMIResponse process(RMICall call);
70 }
71 
72 ///
73 class RMISkeleton(T) : RMICom
74     if (is(T == interface))
75 {
76 protected:
77     T server;
78 
79 public:
80 
81     ///
82     this(T server)
83     {
84         import std.exception : enforce;
85         this.server = enforce(server, "server is null");
86     }
87 
88     override RMIResponse process(RMICall call)
89     {
90         import std.meta;
91         import std.traits;
92 
93         template ov(string s) { alias ov = AliasSeq!(__traits(getOverloads, T, s)); }
94 
95         switch (call.func)
96         {
97             foreach (n, func; staticMap!(ov, __traits(derivedMembers, T)))
98             {
99                 case rmiFunctionName!func:
100                     try
101                     {
102                         auto params = Parameters!func.init;
103                         foreach (i, type; Parameters!func)
104                             params[i] = call.data[i].as!type;
105                         
106                         RMIData resData = null;
107                         enum callstr = "server."~__traits(identifier, func)~"(params)";
108                         static if(is(ReturnType!func == void)) mixin(callstr~";");
109                         else resData = mixin(callstr).serialize;
110 
111                         return RMIResponse(0, call, resData);
112                     }
113                     catch (Throwable e)
114                         return RMIResponse(2, call, e.msg.serialize);
115             }
116             default:
117                 return RMIResponse(1, call, "unknown func".serialize);
118         }
119     }
120 }
121 
122 ///
123 interface RMIStubCom : RMICom
124 {
125     ///
126     string caller() const @property;
127 }
128 
129 ///
130 class RMIStub(T) : T
131 {
132 protected:
133     RMIStubCom com;
134 public:
135     ///
136     this(RMIStubCom com) { this.com = com; }
137 
138     private mixin template overrideIface()
139     {
140         private enum string packCode =
141             q{
142                 enum dummy;
143                 alias self = AliasSeq!(__traits(parent, dummy))[0];
144 
145                 static fname = rmiFunctionName!self;
146 
147                 RMICall call;
148                 call.caller = com.caller;
149                 call.func = fname;
150                 call.ts = Clock.currStdTime;
151                 call.data = rmiEmptyArrayData;
152 
153                 foreach (p; ParameterIdentifierTuple!self)
154                     call.data ~= mixin(p).serialize;
155 
156                 auto result = com.process(call);
157 
158                 enforce(result.status == 0, new RMIException(result.data));
159 
160                 static if (!is(typeof(return) == void))
161                     return result.data.as!(typeof(return));
162             };
163 
164         import std.datetime : Clock;
165         import std.exception : enforce;
166         import std.meta : staticMap;
167         import std.traits : ReturnType, AliasSeq, Parameters, ParameterIdentifierTuple,
168                             functionAttributes, FunctionAttribute;
169 
170         private mixin template impl(F...)
171         {
172             private static string trueParameters(alias FNC)()
173             {
174                 import std.conv : text;
175                 import std.string : join;
176                 string[] ret;
177                 foreach (i, param; Parameters!FNC)
178                     ret ~= `Parameters!(F[0])[`~text(i)~`] __param_` ~ text(i);
179                 return ret.join(", ");
180             }
181 
182             private static string getAttributesString(alias FNC)()
183             {
184                 import std.string : join;
185                 string[] ret;
186                 // TODO
187                 ret ~= functionAttributes!FNC & FunctionAttribute.property ? "@property" : "";
188                 return ret.join(" ");
189             }
190 
191             static if (F.length == 1)
192             {
193                 mixin("override ReturnType!(F[0]) " ~ __traits(identifier, F[0]) ~
194                       `(` ~ trueParameters!(F[0]) ~ `) ` ~ getAttributesString!(F[0]) ~
195                       ` { ` ~ packCode ~ `}`);
196             }
197             else
198             {
199                 mixin impl!(F[0..$/2]);
200                 mixin impl!(F[$/2..$]);
201             }
202         }
203 
204         private template getOverloads(string s)
205         { alias getOverloads = AliasSeq!(__traits(getOverloads, T, s)); }
206 
207         mixin impl!(staticMap!(getOverloads, __traits(derivedMembers, T)));
208     }
209 
210     mixin overrideIface;
211 }
212 
private version (unittest)
{
    // Plain value type, used to check that struct arguments round-trip.
    struct Point { double x, y, z; }

    // Interface exercised by the RMI unittests: an overloaded method, a
    // struct parameter and a @property getter/setter pair.
    interface Test
    {
        int foo(string abc, int xyz);
        string foo(string str);
        string bar(double val);
        double len(Point pnt);
        string state() @property;
        void state(string s) @property;
    }

    // Local implementation of Test; results through the stub are compared
    // against direct calls on it.
    class Impl : Test
    {
        string _state;

        override string state() @property { return _state; }

        override void state(string s) @property { _state = s; }

        override int foo(string abc, int xyz) { return cast(int)(abc.length * xyz); }

        override string foo(string str) { return "<" ~ str ~ ">"; }

        override string bar(double val)
        {
            if (val > 3.14)
                return "big";
            return "small";
        }

        override double len(Point pnt)
        {
            import std.math;
            return sqrt(pnt.x^^2 + pnt.y^^2 + pnt.z^^2);
        }
    }
}
243 
unittest
{
    // Send every Test method through stub -> in-process transport ->
    // skeleton -> Impl, and compare with calling Impl directly.
    auto impl = new Impl;
    auto skeleton = new RMISkeleton!Test(impl);

    // Transport that hands each call straight to the skeleton.
    auto loopback = new class RMIStubCom
    {
        string caller() const @property { return "fake caller"; }
        RMIResponse process(RMICall call) { return skeleton.process(call); }
    };
    auto stub = new RMIStub!Test(loopback);

    assert(impl.foo("hello", 123) == stub.foo("hello", 123));
    foreach (val; [2.71, 3.1415])
        assert(impl.bar(val) == stub.bar(val));
    assert(impl.foo("okda") == stub.foo("okda"));
    assert(impl.len(Point(1,2,3)) == stub.len(Point(1,2,3)));

    // The setter goes through the stub; the getter is checked on both sides.
    static expected = "foo";
    stub.state = expected;
    assert(impl.state == expected);
    assert(stub.state == expected);
}
265 
/// Builds the wire name of one overload: its identifier followed by its
/// comma-separated parameter types, e.g. "foo(int,double,string)" or "bar()".
/// Runs `checkFunction!func` first.
///
/// NOTE(review): types are rendered with `.stringof`, which is not
/// module-qualified. Both ends must be built with the same format.
private string rmiFunctionName(alias func)()
{
    import std.string : join;
    import std.traits : Parameters;

    checkFunction!func;

    string[] typeNames;
    foreach (P; Parameters!func)
        typeNames ~= P.stringof;

    // join yields "" for an empty list, giving "name()" without a special case.
    return __traits(identifier, func) ~ "(" ~ typeNames.join(",") ~ ")";
}
282 
/// Compile-time guard: rejects functions marked `@safe`, `pure` or `@nogc`.
///
/// NOTE(review): presumably because the overrides generated by `RMIStub`
/// carry none of these attributes, and an override cannot drop them.
private void checkFunction(alias func)()
{
    import std.traits : fullyQualifiedName, hasFunctionAttributes;

    // Name used in the diagnostics. fullyQualifiedName walks the whole parent
    // chain, so it works at any nesting depth instead of assuming exactly two
    // enclosing scopes.
    enum funcstr = fullyQualifiedName!func;

    static assert(!hasFunctionAttributes!(func, "@safe"), "@safe not allowed: " ~ funcstr);
    static assert(!hasFunctionAttributes!(func,  "pure"), "pure not allowed: " ~ funcstr);
    static assert(!hasFunctionAttributes!(func, "@nogc"), "@nogc not allowed: " ~ funcstr);
}
294 
295 
unittest
{
    // The nested functions are explicitly @system and append to a static
    // GC array, presumably so they are not inferred @safe/pure/@nogc,
    // which checkFunction rejects.
    static calls = [0];

    void foo(int a, double b, string c) @system { calls ~= 1; }
    void bar() @system { calls ~= 2; }

    // Wire names: identifier, then parameter types joined by bare commas.
    static assert(rmiFunctionName!foo == "foo(int,double,string)");
    static assert(rmiFunctionName!bar == "bar()");
}