1 /**
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "dprof/ipc/ipc_unix_socket.h"
17 #include "dprof/ipc/ipc_message.h"
18 #include "dprof/ipc/ipc_message_protocol.h"
19 #include "dprof/storage.h"
20 #include "serializer/serializer.h"
21 #include "utils/logger.h"
22 #include "utils/pandargs.h"
23 #include "utils/span.h"
24 
25 #include <csignal>
26 #include <queue>
27 #include <sys/socket.h>
28 #include <thread>
29 
30 #include "generated/daemon_options.h"
31 
32 namespace ark::dprof {
CheckVersion(const os::unique_fd::UniqueFd &sock)33 bool CheckVersion(const os::unique_fd::UniqueFd &sock)
34 {
35     // Get version
36     ipc::Message msg;
37     if (RecvMessage(sock.Get(), msg) <= 0) {
38         LOG(ERROR, DPROF) << "Cannot read message";
39         return false;
40     }
41     if (msg.GetId() != ipc::Message::Id::VERSION) {
42         LOG(ERROR, DPROF) << "Incorrect first message id, id=" << static_cast<uint32_t>(msg.GetId());
43         return false;
44     }
45     ipc::protocol::Version tmp;
46     if (!serializer::BufferToStruct<ipc::protocol::VERSION_FCOUNT>(msg.GetData(), msg.GetSize(), tmp)) {
47         LOG(ERROR, DPROF) << "Cannot convert data to version message";
48         return false;
49     }
50     if (tmp.version != ipc::protocol::VERSION) {
51         LOG(ERROR, DPROF) << "Incorrect version:" << tmp.version;
52         return false;
53     }
54     return true;
55 }
56 
ProcessingConnect(const os::unique_fd::UniqueFd &sock)57 static std::unique_ptr<AppData> ProcessingConnect(const os::unique_fd::UniqueFd &sock)
58 {
59     if (!CheckVersion(sock)) {
60         return nullptr;
61     }
62 
63     ipc::protocol::AppInfo ipcAppInfo;
64     {
65         // Get app info
66         ipc::Message msg;
67         if (RecvMessage(sock.Get(), msg) <= 0) {
68             LOG(ERROR, DPROF) << "Cannot read message";
69             return nullptr;
70         }
71         if (msg.GetId() != ipc::Message::Id::APP_INFO) {
72             LOG(ERROR, DPROF) << "Incorrect second message id, id=" << static_cast<uint32_t>(msg.GetId());
73             return nullptr;
74         }
75         if (!serializer::BufferToStruct<ipc::protocol::APP_INFO_FCOUNT>(msg.GetData(), msg.GetSize(), ipcAppInfo)) {
76             LOG(ERROR, DPROF) << "Cannot convert data to a app info message";
77             return nullptr;
78         }
79     }
80 
81     // Get features data
82     AppData::FeaturesMap featuresMap;
83     for (;;) {
84         ipc::Message msg;
85         int ret = RecvMessage(sock.Get(), msg);
86         if (ret == 0) {
87             // There are no more messages, the socket is closed
88             break;
89         }
90         if (ret < 0) {
91             LOG(ERROR, DPROF) << "Cannot read a feature data message";
92             return nullptr;
93         }
94 
95         ipc::protocol::FeatureData tmp;
96         if (!serializer::BufferToStruct<ipc::protocol::FEATURE_DATA_FCOUNT>(msg.GetData(), msg.GetSize(), tmp)) {
97             LOG(ERROR, DPROF) << "Cannot convert data to a feature data";
98             return nullptr;
99         }
100 
101         featuresMap.emplace(std::pair(std::move(tmp.name), std::move(tmp.data)));
102     }
103 
104     return AppData::CreateByParams(ipcAppInfo.appName, ipcAppInfo.hash, ipcAppInfo.pid, std::move(featuresMap));
105 }
106 
107 class Worker {
108 public:
EnqueueClientSocket(os::unique_fd::UniqueFd clientSock)109     void EnqueueClientSocket(os::unique_fd::UniqueFd clientSock)
110     {
111         os::memory::LockHolder lock(queueLock_);
112         queue_.push(std::move(clientSock));
113         cond_.Signal();
114     }
115 
Start(AppDataStorage *storage)116     void Start(AppDataStorage *storage)
117     {
118         done_ = false;
119         thread_ = std::thread([this](AppDataStorage *strg) { DoRun(strg); }, storage);
120     }
121 
Stop()122     void Stop()
123     {
124         os::memory::LockHolder lock(queueLock_);
125         done_ = true;
126         cond_.Signal();
127         thread_.join();
128     }
129 
DoRun(AppDataStorage *storage)130     void DoRun(AppDataStorage *storage)
131     {
132         while (!done_) {
133             os::unique_fd::UniqueFd clientSock;
134             {
135                 os::memory::LockHolder lock(queueLock_);
136                 while (queue_.empty() && !done_) {
137                     cond_.Wait(&queueLock_);
138                 }
139                 if (done_) {
140                     break;
141                 }
142 
143                 clientSock = std::move(queue_.front());
144                 queue_.pop();
145             }
146 
147             auto appData = ProcessingConnect(clientSock);
148             if (!appData) {
149                 LOG(ERROR, DPROF) << "Connection cannot be processed";
150                 continue;
151             }
152 
153             storage->SaveAppData(*appData);
154         }
155     }
156 
157 private:
158     std::thread thread_;
159     os::memory::Mutex queueLock_;
160     std::queue<os::unique_fd::UniqueFd> queue_;
161     os::memory::ConditionVariable cond_ GUARDED_BY(queueLock_);
162     bool done_ = false;
163 };
164 
165 class ArgsParser {
166 public:
Parse(ark::Span<const char *> args)167     bool Parse(ark::Span<const char *> args)
168     {
169         options_.AddOptions(&parser_);
170         if (!parser_.Parse(args.Size(), args.Data())) {
171             std::cerr << parser_.GetErrorString();
172             return false;
173         }
174         auto err = options_.Validate();
175         if (err) {
176             std::cerr << err.value().GetMessage() << std::endl;
177             return false;
178         }
179         if (options_.GetStorageDir().empty()) {
180             std::cerr << "Option \"storage-dir\" is not set" << std::endl;
181             return false;
182         }
183         return true;
184     }
185 
GetOptionos() const186     const Options &GetOptionos() const
187     {
188         return options_;
189     }
190 
Help() const191     void Help() const
192     {
193         std::cerr << "Usage: " << appName_ << " [OPTIONS]" << std::endl;
194         std::cerr << "optional arguments:" << std::endl;
195         std::cerr << parser_.GetHelpString() << std::endl;
196     }
197 
198 private:
199     std::string appName_;
200     PandArgParser parser_;
201     Options options_ {""};
202 };
203 
// Set asynchronously by SignalHandler() and polled by the accept loop in Main().
// Only `volatile std::sig_atomic_t` (or lock-free atomics) may be safely written from a
// signal handler; a plain bool here is undefined behaviour.
static volatile std::sig_atomic_t g_done = 0;

// Async-signal handler: requests daemon shutdown on SIGINT, SIGHUP or SIGTERM; any other
// signal is ignored. Only stores to g_done, which is async-signal-safe.
static void SignalHandler(int sig)
{
    if (sig == SIGINT || sig == SIGHUP || sig == SIGTERM) {
        g_done = 1;
    }
}
212 
SetupSignals()213 static void SetupSignals()
214 {
215     struct sigaction sa {};
216     PLOG_IF(::memset_s(&sa, sizeof(sa), 0, sizeof(sa)) != 0, FATAL, DPROF) << "memset_s failed";
217     sa.sa_handler = SignalHandler;  // NOLINT(cppcoreguidelines-pro-type-union-access)
218     PLOG_IF(::sigemptyset(&sa.sa_mask) == -1, FATAL, DPROF) << "sigemptyset() failed";
219 
220     PLOG_IF(::sigaction(SIGINT, &sa, nullptr) == -1, FATAL, DPROF) << "sigaction(SIGINT) failed";
221     PLOG_IF(::sigaction(SIGHUP, &sa, nullptr) == -1, FATAL, DPROF) << "sigaction(SIGHUP) failed";
222     PLOG_IF(::sigaction(SIGTERM, &sa, nullptr) == -1, FATAL, DPROF) << "sigaction(SIGTERM) failed";
223 }
224 
Main(ark::Span<const char *> args)225 static int Main(ark::Span<const char *> args)
226 {
227     const int maxPendingConnectionsQueue = 32;
228 
229     ArgsParser parser;
230     if (!parser.Parse(args)) {
231         parser.Help();
232         return -1;
233     }
234     const Options &options = parser.GetOptionos();
235 
236     Logger::InitializeStdLogging(Logger::LevelFromString(options.GetLogLevel()), ark::LOGGER_COMPONENT_MASK_ALL);
237 
238     SetupSignals();
239 
240     auto storage = AppDataStorage::Create(options.GetStorageDir(), true);
241     if (!storage) {
242         LOG(FATAL, DPROF) << "Cannot init storage";
243         return -1;
244     }
245 
246     // Create server socket
247     os::unique_fd::UniqueFd sock(ipc::CreateUnixServerSocket(maxPendingConnectionsQueue));
248     if (!sock.IsValid()) {
249         LOG(FATAL, DPROF) << "Cannot create socket";
250         return -1;
251     }
252 
253     Worker worker;
254     worker.Start(storage.get());
255 
256     LOG(INFO, DPROF) << "Daemon is ready for connections";
257     // Main loop
258     while (!g_done) {
259         os::unique_fd::UniqueFd clientSock(::accept4(sock.Get(), nullptr, nullptr, SOCK_CLOEXEC));
260         if (!clientSock.IsValid()) {
261             if (errno == EINTR) {
262                 continue;
263             }
264             PLOG(FATAL, DPROF) << "accept() failed";
265             return -1;
266         }
267         worker.EnqueueClientSocket(std::move(clientSock));
268     }
269     LOG(INFO, DPROF) << "Daemon has received an end signal and stops";
270     worker.Stop();
271     LOG(INFO, DPROF) << "Daemon is stopped";
272     return 0;
273 }
274 }  // namespace ark::dprof
275 
main(int argc, const char *argv[])276 int main(int argc, const char *argv[])
277 {
278     ark::Span<const char *> args(argv, argc);
279     return ark::dprof::Main(args);
280 }
281