1/*
2 * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "ipc_generator_impl.h"
17
18IpcGeneratorImpl::IpcGeneratorImpl() {}
19
20IpcGeneratorImpl::~IpcGeneratorImpl() {}
21
namespace {
// Template for the generated "<file>.ipc.h".
// Placeholders: #HEAD_FILE_NAME#, #PROTOCOL_ENUM#, #SERVICE_CLASS_NAME#, #RESPONSE_DEFINE#,
// #CLIENT_CLASS_NAME#, #VIRTUAL_RESPONSE_FUNC# (filled by GenHeader()/GenerateHeader()).
// The client's waitingFor/presponse get default initialisers because the generated
// ProtocolProc compares waitingFor against every incoming protocol number (and may then
// dereference presponse) before any request has set them. waitingFor's "not waiting" value is
// UINT32_MAX, the same value the generated code stores via `waitingFor = -1`.
const std::string BASE_HEADER_STRING = R"(
#pragma once

#include "#HEAD_FILE_NAME#.pb.h"
#include "service_base.h"
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>

class SocketContext;
class UnixSocketClient;

#PROTOCOL_ENUM#

class #SERVICE_CLASS_NAME#:public ServiceBase
{
public:
    #SERVICE_CLASS_NAME#();
    bool ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size) override;
#RESPONSE_DEFINE#
};

class #CLIENT_CLASS_NAME#:public ServiceBase
{
public:
    #CLIENT_CLASS_NAME#();

    std::shared_ptr<UnixSocketClient> unixSocketClient_;
    bool Connect(const std::string addrname);
    bool ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size) override;
    google::protobuf::Message *presponse = nullptr;
    uint32_t waitingFor = UINT32_MAX;
#VIRTUAL_RESPONSE_FUNC#
};
)";

// Template for the generated "<file>.ipc.cpp" (filled by GenSource()).
// <chrono> and <cstdio> are included because the generated code uses std::chrono::milliseconds
// (client request stubs) and printf (Connect).
const std::string BASE_SOURCE_STRING = R"(
#include "#HEAD_FILE_NAME#.ipc.h"
#include "#HEAD_FILE_NAME#.pb.h"
#include "socket_context.h"
#include "unix_socket_client.h"
#include "unix_socket_server.h"
#include <chrono>
#include <cstdio>
#include <unistd.h>

namespace {
    constexpr uint32_t WAIT_FOR_EVER = 24 * 60 * 60 * 1000;
}

#SERVICE_CLASS_NAME#::#SERVICE_CLASS_NAME#()
{
    serviceName_ = "#SERVICE_NAME#";
}

#RESPONSE_IMPLEMENT#

#SERVICE_PROTOCOL_PROC_FUNC#
bool #SERVICE_CLASS_NAME#::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
{
    switch (pnum) {
#SERVICE_PROTOCOL_PROC#
    }
    return false;
}

#CLIENT_CLASS_NAME#::#CLIENT_CLASS_NAME#()
{
    unixSocketClient_ = nullptr;
    serviceName_ = "#SERVICE_NAME#";
}
bool #CLIENT_CLASS_NAME#::Connect(const std::string addrname)
{
    if (unixSocketClient_ != nullptr) {
        return false;
    }
    unixSocketClient_ = std::make_shared<UnixSocketClient>();
    if (!unixSocketClient_->Connect(addrname, *this)) {
        printf("Socket Connect failed\n");
        unixSocketClient_ = nullptr;
        return false;
    }
    return true;
}

#CLIENT_SEND_REQUEST_PROC_FUNC#

#CLIENT_SEND_PROTOCOL_PROC_FUNC#
bool #CLIENT_CLASS_NAME#::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
{
    switch (pnum) {
#CLIENT_PROTOCOL_PROC#
    }
    if (waitingFor == pnum) {
        waitingFor = -1;
        mWait_.unlock();
    }
    return false;
}
)";

// Converts a snake_case name to PascalCase, e.g. "plugin_service" -> "PluginService".
// Underscores are dropped. A lowercase ASCII letter is upper-cased when it is the first
// character or directly follows an underscore (possibly with non-lowercase characters in
// between: a pending capitalisation is only consumed by a lowercase letter, so "a_1b" -> "A1B").
// The result is used in generated enumerator names, so this mapping must stay stable.
std::string SwapName(const std::string& s)
{
    std::string ret;
    bool capitalizeNext = true;
    for (char c : s) {
        if (c == '_') {
            capitalizeNext = true;
            continue;
        }
        if (capitalizeNext && c >= 'a' && c <= 'z') {
            ret += static_cast<char>(c - 'a' + 'A');
            capitalizeNext = false;
        } else {
            ret += c;
        }
    }
    return ret;
}

// Returns a copy of `base` with every non-overlapping occurrence of `_from` replaced by `_to`,
// scanning left to right. The scan resumes just past each inserted `_to`, so the loop
// terminates even when `_to` itself contains `_from`, and the whole pass is linear in the
// number of matches rather than rescanning from the start each time.
// An empty `_from` matches nothing, so `base` is returned unchanged.
std::string ReplaceStr(const std::string& base, const std::string& _from, const std::string& _to)
{
    std::string ret = base;
    if (_from.empty()) {
        return ret;
    }
    size_t pos = 0;
    while ((pos = ret.find(_from, pos)) != std::string::npos) {
        ret.replace(pos, _from.length(), _to);
        pos += _to.length();
    }
    return ret;
}
} // namespace
151
152std::string IpcGeneratorImpl::SetNames(std::string fileName, std::string packageName)
153{
154    fileName_ = fileName;
155    packageName_ = packageName + "::";
156    headFileName_ = "";
157
158    for (size_t i = 0; i < fileName.length(); i++) {
159        if (fileName.c_str()[i] == '.') {
160            break;
161        }
162        headFileName_ += fileName.c_str()[i];
163    }
164    baseName_ = SwapName(headFileName_);
165
166    serviceCount_ = 0;
167
168    serviceList_.clear();
169    enumMessageDict_.clear();
170
171    return headFileName_;
172}
173
174bool IpcGeneratorImpl::AddService(std::string serviceName)
175{
176    for (int i = 0; i < serviceCount_; i++) {
177        if (serviceList_[i].serviceName_ == serviceName) {
178            return false;
179        }
180    }
181    serviceList_[serviceCount_].serviceName_ = serviceName;
182    serviceCount_++;
183    return true;
184}
185
186bool IpcGeneratorImpl::AddServiceMethod(std::string serviceName,
187                                        std::string methodName,
188                                        std::string requestName,
189                                        std::string responseName)
190{
191    for (int i = 0; i < serviceCount_; i++) {
192        if (serviceList_[i].serviceName_ == serviceName) {
193            return serviceList_[i].AddMethod(methodName, requestName, responseName);
194        }
195    }
196    return false;
197}
198
199void IpcGeneratorImpl::GenerateHeader(std::string& header_str)
200{
201    for (int i = 0; i < serviceCount_; i++) {
202        std::string server_class_name = serviceList_[i].serviceName_ + "Server";
203        header_str = ReplaceStr(header_str, "#SERVICE_CLASS_NAME#", server_class_name);
204
205        std::string tmp1 = "";
206        std::string tmp2 = "";
207        for (int j = 0; j < serviceList_[i].methodCount_; j++) {
208            tmp1 += "\tvirtual bool " + serviceList_[i].methodList_[j] + "(SocketContext &context," + packageName_ +
209                    serviceList_[i].requestList_[j] + " &request," + packageName_ + serviceList_[i].responseList_[j] +
210                    " &response);\n";
211
212            tmp2 += "\tbool SendResponse" + serviceList_[i].responseList_[j] + "(SocketContext &context," +
213                    packageName_ + serviceList_[i].responseList_[j] + " &response);\n";
214        }
215        tmp1 += "\n" + tmp2;
216        header_str = ReplaceStr(header_str, "#RESPONSE_DEFINE#", tmp1);
217
218        std::string client_class_name = serviceList_[i].serviceName_ + "Client";
219        header_str = ReplaceStr(header_str, "#CLIENT_CLASS_NAME#", client_class_name);
220
221        tmp1 = "";
222        for (int j = 0; j < serviceList_[i].methodCount_; j++) {
223            tmp1 += "\tbool " + serviceList_[i].methodList_[j] + "(" + packageName_ + serviceList_[i].requestList_[j];
224            tmp1 += " &request," + packageName_ + serviceList_[i].responseList_[j];
225            tmp1 += " &response,uint32_t timeout_ms=5000);\n";
226            tmp1 += "\tbool " + serviceList_[i].methodList_[j] + "(" + packageName_ + serviceList_[i].requestList_[j];
227            tmp1 += " &request);\n";
228        }
229        tmp1 += "\n";
230        for (int j = 0; j < serviceList_[i].methodCount_; j++) {
231            tmp1 += "\tvirtual bool On" + serviceList_[i].responseList_[j] + "(SocketContext &context," + packageName_;
232            tmp1 += serviceList_[i].responseList_[j] + " &response);\n";
233        }
234
235        header_str = ReplaceStr(header_str, "#VIRTUAL_RESPONSE_FUNC#", tmp1);
236    }
237}
238
239std::string IpcGeneratorImpl::GenHeader()
240{
241    std::string header_str = BASE_HEADER_STRING;
242    std::string tmp1;
243    header_str = ReplaceStr(header_str, "#HEAD_FILE_NAME#", headFileName_);
244    const int numTwo = 2;
245
246    if (serviceCount_ > 0) {
247        tmp1 = "enum {\n";
248        for (int i = 0; i < serviceCount_; i++) {
249            for (int j = 0; j < serviceList_[i].methodCount_; j++) {
250                tmp1 += "\tIpcProtocol" + baseName_ + serviceList_[i].requestList_[j];
251                tmp1 += "=" + std::to_string(j * numTwo) + ",\n";
252                tmp1 += "\tIpcProtocol" + baseName_ + serviceList_[i].responseList_[j];
253                tmp1 += "=" + std::to_string(j * numTwo + 1) + ",\n";
254            }
255        }
256        tmp1 += "};";
257    } else {
258        tmp1 = "";
259    }
260    header_str = ReplaceStr(header_str, "#PROTOCOL_ENUM#", tmp1);
261
262    GenerateHeader(header_str);
263    header_str = ReplaceStr(header_str, "\t", "    ");
264    return header_str;
265}
266
namespace {
// Template for one server-side SendResponse<Response>() definition (expanded per method by
// GenSendResponseImpl()). Placeholders: #SERVER_CLASS_NAME#, #RESPONSE_NAME#, #PACKAGE_NAME#
// (the package prefix including its trailing "::") and #ENUM_STR# (the response's protocol
// enumerator).
// NOTE(review): the generated function ignores SendProtobuf's result and always returns
// false — confirm no caller treats the return value as a delivery status.
const std::string SEND_RESPONSE_IMPL_STRING = R"(
bool #SERVER_CLASS_NAME#::SendResponse#RESPONSE_NAME#(SocketContext &context,
                                                      #PACKAGE_NAME##RESPONSE_NAME# &response) {
    context.SendProtobuf(#ENUM_STR#, response);
    return false;
}
)";
} // namespace
276std::string IpcGeneratorImpl::GenSendResponseImpl(int servicep, const std::string& server_class_name)
277{
278    std::string ret = "";
279    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
280        std::string enum_str = "IpcProtocol" + baseName_ + serviceList_[servicep].responseList_[j];
281        std::string tmp = ReplaceStr(SEND_RESPONSE_IMPL_STRING, "#SERVER_CLASS_NAME#", server_class_name);
282        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
283        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
284        tmp = ReplaceStr(tmp, "#ENUM_STR#", enum_str);
285        ret += tmp;
286    }
287    return ret;
288}
namespace {
// Template for the client's default On<Response>() callback (expanded per method by
// GenOnResponseImpl()). The generated body does nothing and returns false.
// Placeholders: #CLIENT_CLASS_NAME#, #RESPONSE_NAME#, #PACKAGE_NAME#.
const std::string ON_RESPONSE_IMPL_STRING = R"(
bool #CLIENT_CLASS_NAME#::On#RESPONSE_NAME#(SocketContext &context, #PACKAGE_NAME##RESPONSE_NAME# &response) {
    return false;
}
)";
} // namespace
296std::string IpcGeneratorImpl::GenOnResponseImpl(int servicep, const std::string& client_class_name)
297{
298    std::string ret = "";
299    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
300        std::string tmp = ReplaceStr(ON_RESPONSE_IMPL_STRING, "#CLIENT_CLASS_NAME#", client_class_name);
301        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
302        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
303        ret += tmp;
304    }
305    return ret;
306}
namespace {
// Template for the server's default per-method handler (expanded by GenServiceCallImpl()).
// The generated body returns false, which the generated ProtocolProc dispatch treats as
// "do not send a response". Placeholders: #SERVER_CLASS_NAME#, #METHOD_NAME#,
// #PACKAGE_NAME#, #REQUEST_NAME#, #RESPONSE_NAME#.
const std::string SERVICE_CALL_IMPL_STRING = R"(
bool #SERVER_CLASS_NAME#::#METHOD_NAME#(SocketContext &context,
                                        #PACKAGE_NAME##REQUEST_NAME# &request,
                                        #PACKAGE_NAME##RESPONSE_NAME# &response) {
    return false;
}
)";
} // namespace
316std::string IpcGeneratorImpl::GenServiceCallImpl(int servicep, const std::string& server_class_name)
317{
318    std::string ret = "";
319    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
320        std::string tmp = ReplaceStr(SERVICE_CALL_IMPL_STRING, "#SERVER_CLASS_NAME#", server_class_name);
321        tmp = ReplaceStr(tmp, "#SERVER_CLASS_NAME#", server_class_name);
322        tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
323        tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
324        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
325        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
326        ret += tmp;
327    }
328    return ret;
329}
namespace {
// `case` templates for the SERVER's ProtocolProc switch, despite the CLIENT_ prefix: they are
// expanded by GenClientProcImpl() into the #SERVICE_PROTOCOL_PROC# placeholder.
// Each case parses the request from (buf, size) and calls the method handler.
//
// Default form: if the handler returns true, the response is sent back under the response's
// protocol number.
// NOTE(review): the ParseFromArray result is ignored, so a malformed request is still
// dispatched.
const std::string CLIENT_PROC_IMPL_STRING = R"(
    case IpcProtocol#BASE_NAME##REQUEST_NAME#:{
        #PACKAGE_NAME##REQUEST_NAME# request;
        #PACKAGE_NAME##RESPONSE_NAME# response;
        request.ParseFromArray(buf, size);
        if (#METHOD_NAME#(context, request, response)) {
            context.SendProtobuf(IpcProtocol#BASE_NAME##RESPONSE_NAME#, response);
        }
    }
        break;
)";
// Form used for a method named "NotifyResult": the handler is called, but its result is
// ignored and no response is ever sent.
const std::string CLIENT_PROC_NOTIFYRESULT_STRING = R"(
    case IpcProtocol#BASE_NAME##REQUEST_NAME#:{
        #PACKAGE_NAME##REQUEST_NAME# request;
        #PACKAGE_NAME##RESPONSE_NAME# response;
        request.ParseFromArray(buf, size);
        #METHOD_NAME#(context, request, response);
    }
        break;
)";
} // namespace
352std::string IpcGeneratorImpl::GenClientProcImpl(int servicep)
353{
354    std::string ret = "";
355    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
356        std::string tmp = ReplaceStr(CLIENT_PROC_IMPL_STRING, "#BASE_NAME#", baseName_);
357        if (serviceList_[servicep].methodList_[j] == "NotifyResult") {
358            tmp = ReplaceStr(CLIENT_PROC_NOTIFYRESULT_STRING, "#BASE_NAME#", baseName_);
359        }
360        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
361        tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
362        tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
363        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
364        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
365        ret += tmp;
366    }
367    return ret;
368}
namespace {
// Template for the two client call stubs emitted per method (expanded by
// GenClientRequestImpl()). Placeholders: #CLIENT_CLASS_NAME#, #METHOD_NAME#, #PACKAGE_NAME#,
// #REQUEST_NAME#, #RESPONSE_NAME#, #BASE_NAME#.
//
// Overload 1 (request, response, timeout_ms):
//  1. Lock mWait_.
//  2. Record the awaited response protocol in waitingFor and the destination in presponse.
//  3. Send the request.
//  4. Block in try_lock_for until the generated client ProtocolProc unlocks mWait_ on receipt
//     of that protocol, or until timeout_ms expires.
//  timeout_ms == 0 means wait WAIT_FOR_EVER.
// Overload 2 (request): send only; returns true once the request has been handed to the socket.
//
// Both overloads return false immediately when the client is not connected
// (unixSocketClient_ == nullptr). Previously the send-only overload dereferenced the null
// pointer, and the waiting overload sat out its whole timeout before failing.
// NOTE(review): this handshake unlocks mWait_ from the receiving thread and re-locks it from
// the owning thread. With std::mutex/std::timed_mutex that is undefined behaviour. mWait_ is
// presumably declared in ServiceBase — confirm its type.
const std::string CLIENT_REQUEST_IMPL_STRING = R"(
bool #CLIENT_CLASS_NAME#::#METHOD_NAME#(#PACKAGE_NAME##REQUEST_NAME# &request,
                                        #PACKAGE_NAME##RESPONSE_NAME# &response,
                                        uint32_t timeout_ms)
{
    if (unixSocketClient_ == nullptr) {
        return false;
    }
    mWait_.lock();
    if (timeout_ms == 0) {
        timeout_ms = WAIT_FOR_EVER;
    }
    waitingFor = IpcProtocol#BASE_NAME##RESPONSE_NAME#;
    presponse = &response;
    unixSocketClient_->SendProtobuf(IpcProtocol#BASE_NAME##REQUEST_NAME#, request);
    if (mWait_.try_lock_for(std::chrono::milliseconds(timeout_ms))) {
        mWait_.unlock();
        return true;
    }
    waitingFor = -1;
    mWait_.unlock();
    return false;
}
bool #CLIENT_CLASS_NAME#::#METHOD_NAME#(#PACKAGE_NAME##REQUEST_NAME# &request)
{
    if (unixSocketClient_ == nullptr) {
        return false;
    }
    unixSocketClient_->SendProtobuf(IpcProtocol#BASE_NAME##REQUEST_NAME#, request);
    return true;
}
)";
} // namespace
399std::string IpcGeneratorImpl::GenClientRequestImpl(int servicep, const std::string& client_class_name)
400{
401    std::string ret = "";
402    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
403        std::string tmp = ReplaceStr(CLIENT_REQUEST_IMPL_STRING, "#CLIENT_CLASS_NAME#", client_class_name);
404        tmp = ReplaceStr(tmp, "#METHOD_NAME#", serviceList_[servicep].methodList_[j]);
405        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
406        tmp = ReplaceStr(tmp, "#REQUEST_NAME#", serviceList_[servicep].requestList_[j]);
407        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
408        tmp = ReplaceStr(tmp, "#BASE_NAME#", baseName_);
409        ret += tmp;
410    }
411    return ret;
412}
namespace {
// `case` template for the CLIENT's ProtocolProc switch, despite the SERVICE_ prefix: it is
// expanded by GenServiceProcImpl() into the #CLIENT_PROTOCOL_PROC# placeholder.
// If a waiting call is expecting this protocol (waitingFor == pnum), the payload is parsed
// into the caller's response object (presponse). Otherwise it is parsed into a local message
// and passed to the On<Response>() callback. #NUM# gives each case's local variable a
// distinct name.
const std::string SERVICE_PROC_IMPL_STRING = R"(
    case IpcProtocol#BASE_NAME##RESPONSE_NAME#:
        {
            if (waitingFor==pnum) {
                presponse->ParseFromArray(buf, size);
            }
            else {
                #PACKAGE_NAME##RESPONSE_NAME# response#NUM#;
                response#NUM#.ParseFromArray(buf, size);
                On#RESPONSE_NAME#(context, response#NUM#);
            }
        }
        break;
)";
} // namespace
429std::string IpcGeneratorImpl::GenServiceProcImpl(int servicep)
430{
431    std::string ret = "";
432    for (int j = 0; j < serviceList_[servicep].methodCount_; j++) {
433        std::string tmp = ReplaceStr(SERVICE_PROC_IMPL_STRING, "#BASE_NAME#", baseName_);
434        tmp = ReplaceStr(tmp, "#RESPONSE_NAME#", serviceList_[servicep].responseList_[j]);
435        tmp = ReplaceStr(tmp, "#PACKAGE_NAME#", packageName_);
436        tmp = ReplaceStr(tmp, "#NUM#", std::to_string(j + 1));
437
438        ret += tmp;
439    }
440    return ret;
441}
442
443std::string IpcGeneratorImpl::GenSource()
444{
445    std::string source_str = BASE_SOURCE_STRING;
446
447    source_str = ReplaceStr(source_str, "#HEAD_FILE_NAME#", headFileName_);
448
449    for (int i = 0; i < serviceCount_; i++) {
450        std::string server_class_name = serviceList_[i].serviceName_ + "Server";
451        source_str = ReplaceStr(source_str, "#SERVICE_CLASS_NAME#", server_class_name);
452        source_str = ReplaceStr(source_str, "#SERVICE_NAME#", serviceList_[i].serviceName_);
453        std::string client_class_name = serviceList_[i].serviceName_ + "Client";
454        source_str = ReplaceStr(source_str, "#CLIENT_CLASS_NAME#", client_class_name);
455
456        source_str = ReplaceStr(source_str, "#RESPONSE_IMPLEMENT#", GenSendResponseImpl(i, server_class_name));
457        source_str = ReplaceStr(source_str, "#CLIENT_SEND_REQUEST_PROC_FUNC#", GenOnResponseImpl(i, client_class_name));
458
459        source_str = ReplaceStr(source_str, "#SERVICE_PROTOCOL_PROC_FUNC#", GenServiceCallImpl(i, server_class_name));
460        source_str = ReplaceStr(source_str, "#SERVICE_PROTOCOL_PROC#", GenClientProcImpl(i));
461        source_str = ReplaceStr(source_str, "#SERVICE_NAME#", serviceList_[i].serviceName_);
462
463        source_str = ReplaceStr(source_str, "#CLIENT_PROTOCOL_PROC#", GenServiceProcImpl(i));
464        source_str =
465            ReplaceStr(source_str, "#CLIENT_SEND_PROTOCOL_PROC_FUNC#", GenClientRequestImpl(i, client_class_name));
466    }
467
468    source_str = ReplaceStr(source_str, "\t", "    ");
469    return source_str;
470}
471