1 /*
2  * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <malloc.h>
16 #include "hook_socket_client.h"
17 
18 #include "common.h"
19 #include "hook_common.h"
20 #include "unix_socket_client.h"
21 #include "share_memory_allocator.h"
22 #include "logging.h"
23 #include "sampling.h"
24 #include <unistd.h>
25 #include <sys/stat.h>
26 #include <sys/types.h>
27 #include <fcntl.h>
28 #include <cstdio>
29 #include <cstring>
30 #include <iostream>
31 
namespace {
// SendStackWithPayload() flushes the stack writer once every FLUSH_FLAG successful writes.
constexpr int FLUSH_FLAG = 20;
// Size in bytes of the native-memory-detail text buffer; the stored text plus its
// terminating NUL must fit within it.
constexpr uint32_t MEMCHECK_DETAILINFO_MAXSIZE = 102400;

// Number of payload writes performed so far (drives the periodic flush).
std::atomic<uint64_t> g_flushCount{0};
// Becomes true once DisableHook() has run; afterwards no further payloads are sent.
std::atomic<bool> g_disableHook{false};

// Write cursor handed to the malloc_stats_print() callback:
// `buf` is the destination buffer, `pos` the number of bytes already written to it.
struct OptArg {
    size_t pos;
    char *buf;
};
} // namespace
44 
45 
/**
 * Formats the current local time as "YYYY/MM/DD HH:MM:SS".
 *
 * @return the formatted timestamp, or "error time format!" when the local time
 *         cannot be obtained or formatted.
 */
static std::string GetRealTime()
{
    constexpr size_t stampCapacity = 64;
    const time_t nowSeconds = time(nullptr);

    struct tm localParts {};
    if (localtime_r(&nowSeconds, &localParts) == nullptr) {
        return "error time format!";
    }

    char stamp[stampCapacity] = {};
    const size_t written = strftime(stamp, sizeof(stamp), "%Y/%m/%d %H:%M:%S", &localParts);
    if (written == 0) {
        return "error time format!";
    }
    // strftime reports how many characters it produced (excluding the NUL).
    return std::string(stamp, written);
}
58 
NmdWriteStat(void *arg, const char *buf)59 static void NmdWriteStat(void *arg, const char *buf)
60 {
61     struct OptArg *opt = static_cast<struct OptArg*>(arg);
62     std::string getNmdTime = std::to_string(getpid()) + " " + GetRealTime() + "\n";
63     size_t nmdTimeLen = getNmdTime.size();
64     if (strncpy_s(opt->buf + opt->pos, MEMCHECK_DETAILINFO_MAXSIZE - opt->pos,
65                   getNmdTime.c_str(), nmdTimeLen) != EOK) {
66         return;
67     }
68     opt->pos += nmdTimeLen;
69 
70     size_t len = strlen(buf);
71     if (len + opt->pos + 1 > MEMCHECK_DETAILINFO_MAXSIZE) {
72         return;
73     }
74     if (strncpy_s(opt->buf + opt->pos, MEMCHECK_DETAILINFO_MAXSIZE - opt->pos, buf, len) != EOK) {
75         return;
76     }
77     opt->pos += len;
78 }
79 
HookSocketClient(int pid, ClientConfig *config, Sampling *sampler, void (*disableHookCallback)())80 HookSocketClient::HookSocketClient(int pid, ClientConfig *config, Sampling *sampler, void (*disableHookCallback)())
81     : pid_(pid), config_(config), sampler_(sampler), disableHookCallback_(disableHookCallback)
82 {
83     smbFd_ = 0;
84     eventFd_ = 0;
85     unixSocketClient_ = nullptr;
86     serviceName_ = "HookService";
87     Connect(DEFAULT_UNIX_SOCKET_HOOK_FULL_PATH);
88 }
89 
~HookSocketClient()90 HookSocketClient::~HookSocketClient()
91 {
92     if (stackWriter_) {
93         stackWriter_->Flush();
94     }
95     unixSocketClient_ = nullptr;
96     stackWriter_ = nullptr;
97 }
98 
Connect(const std::string addrname)99 bool HookSocketClient::Connect(const std::string addrname)
100 {
101     if (unixSocketClient_ != nullptr) {
102         return false;
103     }
104     unixSocketClient_ = std::make_shared<UnixSocketClient>();
105     if (!unixSocketClient_->Connect(addrname, *this)) {
106         unixSocketClient_ = nullptr;
107         return false;
108     }
109 
110     unixSocketClient_->SendHookConfig(reinterpret_cast<uint8_t *>(&pid_), sizeof(pid_));
111     return true;
112 }
113 
ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)114 bool HookSocketClient::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
115 {
116     if (size != sizeof(ClientConfig)) {
117         return true;
118     }
119     *config_ = *reinterpret_cast<ClientConfig *>(const_cast<int8_t*>(buf));
120     config_->maxStackDepth  = config_->maxStackDepth > MAX_UNWIND_DEPTH ? MAX_UNWIND_DEPTH : config_->maxStackDepth;
121     std::string configStr = config_->ToString();
122     sampler_->InitSampling(config_->sampleInterval);
123     smbFd_ = context.ReceiveFileDiscriptor();
124     eventFd_ = context.ReceiveFileDiscriptor();
125     std::string smbName = "hooknativesmb_" + std::to_string(pid_);
126     stackWriter_ = std::make_shared<StackWriter>(smbName, config_->shareMemorySize,
127         smbFd_, eventFd_, config_->isBlocked);
128     struct mallinfo2 mi = mallinfo2();
129     nmdType_ = config_->nmdType;
130     if (nmdType_ == 0) {
131         SendNmdInfo();
132     }
133     return true;
134 }
135 
SendStack(const void* data, size_t size)136 bool HookSocketClient::SendStack(const void* data, size_t size)
137 {
138     if (stackWriter_ == nullptr || unixSocketClient_ == nullptr) {
139         return false;
140     }
141 
142     if (!unixSocketClient_->SendHeartBeat()) {
143         return false;
144     }
145 
146     stackWriter_->WriteTimeout(data, size);
147     stackWriter_->Flush();
148 
149     return true;
150 }
151 
SendStackWithPayload(const void* data, size_t size, const void* payload, size_t payloadSize)152 bool HookSocketClient::SendStackWithPayload(const void* data, size_t size, const void* payload,
153     size_t payloadSize)
154 {
155     if (stackWriter_ == nullptr || unixSocketClient_ == nullptr || g_disableHook) {
156         return false;
157     } else if (unixSocketClient_->GetClientState() == CLIENT_STAT_THREAD_EXITED) {
158         DisableHook();
159         return false;
160     }
161 
162     bool ret = stackWriter_->WriteWithPayloadTimeout(data, size, payload, payloadSize,
163                                                      std::bind(&HookSocketClient::PeerIsConnected, this));
164     if (!ret && config_->isBlocked) {
165         DisableHook();
166         return false;
167     }
168     ++g_flushCount;
169     if (g_flushCount % FLUSH_FLAG == 0) {
170         stackWriter_->Flush();
171     }
172     return true;
173 }
174 
Flush()175 void HookSocketClient::Flush()
176 {
177     if (stackWriter_ == nullptr || unixSocketClient_ == nullptr) {
178         return;
179     }
180     stackWriter_->Flush();
181 }
182 
DisableHook()183 void HookSocketClient::DisableHook()
184 {
185     bool expected = false;
186     if (g_disableHook.compare_exchange_strong(expected, true, std::memory_order_release, std::memory_order_relaxed)) {
187         HILOG_BASE_INFO(LOG_CORE, "%s", __func__);
188         if (disableHookCallback_) {
189             disableHookCallback_();
190         }
191     }
192 }
193 
PeerIsConnected()194 bool HookSocketClient::PeerIsConnected()
195 {
196     return !unixSocketClient_->IsConnected();
197 }
198 
SendNmdInfo()199 bool HookSocketClient::SendNmdInfo()
200 {
201     if (!config_->printNmd) {
202         return false;
203     }
204     void* nmdBuf = malloc(MEMCHECK_DETAILINFO_MAXSIZE);
205     if (nmdBuf == nullptr) {
206         return false;
207     }
208     struct OptArg opt = {0, reinterpret_cast<char*>(nmdBuf) };
209     malloc_stats_print(NmdWriteStat, &opt, "a");
210     StackRawData rawdata = {{{{0}}}};
211     rawdata.type = NMD_MSG;
212     if (stackWriter_) {
213         stackWriter_->WriteWithPayloadTimeout(&rawdata, sizeof(BaseStackRawData),
214                                               reinterpret_cast<int8_t*>(opt.buf), strlen(opt.buf) + 1,
215                                               std::bind(&HookSocketClient::PeerIsConnected, this));
216     }
217     free(nmdBuf);
218     return true;
219 }
220 
SendEndMsg()221 bool HookSocketClient::SendEndMsg()
222 {
223     StackRawData rawdata = {{{{0}}}};
224     rawdata.type = END_MSG;
225     if (stackWriter_) {
226         stackWriter_->WriteTimeout(&rawdata, sizeof(BaseStackRawData));
227     }
228     return true;
229 }