# Copyright 2021 Nikita Melekhin. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
4
class Message:
    """Plain record describing one IPC message for the code generator.

    Every constructor argument is stored unchanged as an attribute of the same
    name: ``name``, ``id``, ``reply_id``, ``decoder_magic``, ``params`` and
    ``protected``. The generator reads ``params`` as a sequence of pairs whose
    first item is the C++ type and whose second item is the parameter name.
    """

    def __init__(self, name, id, reply_id, decoder_magic, params, protected=False):
        # Attributes are bound left-to-right in the documented order.
        (
            self.name,
            self.id,
            self.reply_id,
            self.decoder_magic,
            self.params,
            self.protected,
        ) = (name, id, reply_id, decoder_magic, params, protected)
13
14
class Generator:
    """Emit C++ message and decoder classes for parsed IPC decoders.

    Every emitting helper writes through ``out`` to ``self.output``. ``generate``
    opens that file itself; when driving the helpers directly, assign any
    object with a ``write`` method (e.g. ``io.StringIO``) to ``self.output``.

    A ``decoder`` object is expected to expose ``name``, ``magic``,
    ``protected``, ``messages`` (mapping message name -> params) and
    ``functions`` (mapping handled message name -> reply message name).
    Each param is a pair whose first item is the C++ type and whose second
    item is the parameter name.
    """

    # Text written once per indentation level by ``out``.
    INDENT = " "

    # Parameter types that generated getters return by value (const accessor);
    # every other type gets a mutable reference accessor.
    _VALUE_TYPES = ("int", "uint32_t", "bool", "int32_t")

    # Fixed preamble written at the top of every generated header.
    _PREAMBLE = (
        "// Auto generated with utils/ConnectionCompiler",
        "// See .ipc file",
        "",
        "#pragma once",
        "#include <libipc/Encoder.h>",
        "#include <libipc/ClientConnection.h>",
        "#include <libipc/ServerConnection.h>",
        "#include <libipc/StringEncoder.h>",
        "#include <libipc/VectorEncoder.h>",
        "#include <new>",
        "#include <libg/Rect.h>",
        "",
    )

    def __init__(self):
        self.output = None

    def out(self, str, tabs=0):
        """Write ``str`` as one line to ``self.output``, indented ``tabs`` levels.

        The parameter keeps its historical name ``str`` (shadowing the builtin
        only inside this method) so keyword callers keep working.
        """
        self.output.write(self.INDENT * tabs + str + "\n")

    @staticmethod
    def _message_ids(decoder):
        """Return {message name: id}, numbering messages from 1 in declaration order.

        Generated ``id()`` getters, ``decode()`` and ``handle()`` must agree on
        these numbers, so all of them take them from here.
        """
        return {name: msg_id for msg_id, name in enumerate(decoder.messages, start=1)}

    def params_readable(self, params):
        """Render params as a C++ parameter list, e.g. ``"int a,bool b"``."""
        return ",".join("{0} {1}".format(p[0], p[1]) for p in params)

    def message_create_std_funcs(self, msg):
        """Emit the id/reply_id/(key)/decoder_magic overrides and one getter per param."""
        self.out("int id() const override {{ return {0}; }}".format(msg.id), 1)
        self.out("int reply_id() const override {{ return {0}; }}".format(msg.reply_id), 1)
        if msg.protected:
            self.out("int key() const override { return m_key; }", 1)
        self.out("int decoder_magic() const override {{ return {0}; }}".format(msg.decoder_magic), 1)
        for param in msg.params:
            if param[0] in self._VALUE_TYPES:
                self.out("{0} {1}() const {{ return m_{1}; }}".format(param[0], param[1]), 1)
            else:
                self.out("{0}& {1}() {{ return m_{1}; }}".format(param[0], param[1]), 1)

    def message_create_vars(self, msg):
        """Emit one ``m_<name>`` member per param, plus ``m_key`` for protected messages."""
        if msg.protected:
            self.out("message_key_t m_key;", 1)
        for param in msg.params:
            self.out("{0} m_{1};".format(param[0], param[1]), 1)

    def message_create_constructor(self, msg):
        """Emit a constructor that initializes every member from its argument.

        Protected messages take an extra leading ``message_key_t key`` argument.
        """
        params = list(msg.params)
        if msg.protected:
            params = [("message_key_t", "key")] + params

        signature = "{0}({1})".format(msg.name, self.params_readable(params))
        if not params:
            self.out(signature + " {}", 1)
            return

        self.out(signature, 1)
        for index, param in enumerate(params):
            # C++ member-initializer list: ':' before the first entry, ',' after.
            sep = ":" if index == 0 else ","
            self.out("{0} m_{1}({1})".format(sep, param[1]), 2)
        self.out("{", 1)
        self.out("}", 1)

    def message_create_encoder(self, msg):
        """Emit ``encode()``, appending magic, id, (key) and every param in order."""
        self.out("EncodedMessage encode() const override", 1)
        self.out("{", 1)
        self.out("EncodedMessage buffer;", 2)
        self.out("Encoder::append(buffer, decoder_magic());", 2)
        self.out("Encoder::append(buffer, id());", 2)
        if msg.protected:
            self.out("Encoder::append(buffer, key());", 2)
        for param in msg.params:
            self.out("Encoder::append(buffer, m_{0});".format(param[1]), 2)
        self.out("return buffer;", 2)
        self.out("}", 1)

    def generate_message(self, msg):
        """Emit the full C++ class for one message."""
        self.out("class {0} : public Message {{".format(msg.name))
        self.out("public:")
        self.message_create_constructor(msg)
        self.message_create_std_funcs(msg)
        self.message_create_encoder(msg)
        self.out("private:")
        self.message_create_vars(msg)
        self.out("};")
        self.out("")

    @staticmethod
    def _decoder_var_names(messages):
        """Map each distinct (type, name) param to a unique C++ local variable name.

        The first type seen for a parameter name gets ``var_<name>``. A later
        param with the same name but a different type gets a numbered variant
        (``var_<name>_1``, ...) instead of reusing a local of the wrong type.
        """
        var_names = {}
        taken = set()
        for params in messages.values():
            for param in params:
                key = (param[0], param[1])
                if key in var_names:
                    continue
                candidate = "var_{0}".format(param[1])
                suffix = 1
                while candidate in taken:
                    candidate = "var_{0}_{1}".format(param[1], suffix)
                    suffix += 1
                var_names[key] = candidate
                taken.add(candidate)
        return var_names

    def decoder_create_vars(self, messages, offset=0):
        """Declare one local per distinct (type, name) param across ``messages``.

        Returns the {(type, name): variable name} mapping that was emitted, so
        callers can pass it to ``decoder_decode_message``.
        """
        var_names = self._decoder_var_names(messages)
        for (ptype, _), var in var_names.items():
            self.out("{0} {1};".format(ptype, var), offset)
        return var_names

    def decoder_decode_message(self, msg, offset=0, var_names=None):
        """Emit decode calls for each param of ``msg`` and the ``return new ...`` line.

        ``var_names`` is the mapping returned by ``decoder_create_vars``. When
        omitted, each param is assumed to live in ``var_<name>``.
        """
        if var_names is None:
            var_names = {}
        locals_ = [var_names.get((p[0], p[1]), "var_{0}".format(p[1])) for p in msg.params]
        for var in locals_:
            self.out("Encoder::decode(buf, decoded_msg_len, {0});".format(var), offset)
        args = (["secret_key"] if msg.protected else []) + locals_
        self.out("return new {0}({1});".format(msg.name, ", ".join(args)), offset)

    def decoder_create_std_funcs(self, decoder):
        """Emit the decoder's ``magic()`` getter."""
        self.out("int magic() const {{ return {0}; }}".format(decoder.magic), 1)

    def decoder_create_decode(self, decoder):
        """Emit ``decode()``.

        The generated function reads the magic and message id and restores
        ``decoded_msg_len`` on a magic mismatch or an unknown id. Otherwise it
        decodes the params of the matching message into locals and constructs
        the message from them.
        """
        self.out(
            "std::unique_ptr<Message> decode(const char* buf, size_t size, size_t& decoded_msg_len) override", 1)
        self.out("{", 1)
        self.out("int msg_id, decoder_magic;", 2)
        self.out("size_t saved_dml = decoded_msg_len;", 2)
        self.out("Encoder::decode(buf, decoded_msg_len, decoder_magic);", 2)
        self.out("Encoder::decode(buf, decoded_msg_len, msg_id);", 2)
        self.out("if (magic() != decoder_magic) {", 2)
        self.out("decoded_msg_len = saved_dml;", 3)
        self.out("return nullptr;", 3)
        self.out("}", 2)

        if decoder.protected:
            self.out("message_key_t secret_key;", 2)
            self.out("Encoder::decode(buf, decoded_msg_len, secret_key);", 2)
            self.out("", 0)

        var_names = self.decoder_create_vars(decoder.messages, 2)

        msg_ids = self._message_ids(decoder)
        self.out("", 2)
        self.out("switch(msg_id) {", 2)
        for name, params in decoder.messages.items():
            self.out("case {0}:".format(msg_ids[name]), 2)
            # Decoding never needs the real reply_id, so 0 is a placeholder.
            self.decoder_decode_message(
                Message(name, msg_ids[name], 0, decoder.magic, params, decoder.protected),
                3, var_names)

        self.out("default:", 2)
        self.out("decoded_msg_len = saved_dml;", 3)
        self.out("return nullptr;", 3)
        self.out("}", 2)
        self.out("}", 1)
        self.out("", 1)

    def decoder_create_handle(self, decoder):
        """Emit ``handle(Message&)``, dispatching each message listed in ``functions``."""
        self.out("std::unique_ptr<Message> handle(Message& msg) override", 1)
        self.out("{", 1)
        self.out("if (magic() != msg.decoder_magic()) {", 2)
        self.out("return nullptr;", 3)
        self.out("}", 2)

        msg_ids = self._message_ids(decoder)
        self.out("", 2)
        self.out("switch(msg.id()) {", 2)
        for name in decoder.messages:
            if name in decoder.functions:
                self.out("case {0}:".format(msg_ids[name]), 2)
                self.out("return handle(static_cast<{0}&>(msg));".format(name), 3)

        self.out("default:", 2)
        self.out("return nullptr;", 3)
        self.out("}", 2)
        self.out("}", 1)
        self.out("", 1)

    def decoder_create_virtual_handle(self, decoder):
        """Emit a default (returns nullptr) virtual ``handle`` overload per handled message."""
        for accept in decoder.functions:
            self.out(
                "virtual std::unique_ptr<Message> handle({0}& msg) {{ return nullptr; }}".format(accept), 1)

    def generate_decoder(self, decoder):
        """Emit the full C++ decoder class."""
        self.out("class {0} : public MessageDecoder {{".format(decoder.name))
        self.out("public:")
        self.out("{0}() {{}}".format(decoder.name), 1)
        self.decoder_create_std_funcs(decoder)
        self.decoder_create_decode(decoder)
        self.decoder_create_handle(decoder)
        self.decoder_create_virtual_handle(decoder)
        self.out("};")
        self.out("")

    def includes(self):
        """Emit the generated header's fixed preamble."""
        for line in self._PREAMBLE:
            self.out(line)

    def generate(self, filename, decoders):
        """Write the message and decoder classes for every decoder to ``filename``.

        The file is closed even if generation fails partway. For example, a
        reply name that is not a declared message raises ``KeyError``.
        """
        with open(filename, "w+") as output:
            self.output = output
            self.includes()
            for decoder in decoders:
                msg_ids = self._message_ids(decoder)
                for name, params in decoder.messages.items():
                    reply_name = decoder.functions.get(name)
                    # -1 marks a message that expects no reply.
                    reply_id = -1 if reply_name is None else msg_ids[reply_name]
                    self.generate_message(
                        Message(name, msg_ids[name], reply_id, decoder.magic, params, decoder.protected))
                self.generate_decoder(decoder)