1/*
2 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15#include <malloc.h>
16#include "hook_socket_client.h"
17
18#include "common.h"
19#include "hook_common.h"
20#include "unix_socket_client.h"
21#include "share_memory_allocator.h"
22#include "logging.h"
23#include "sampling.h"
24#include <unistd.h>
25#include <sys/stat.h>
26#include <sys/types.h>
27#include <fcntl.h>
28#include <cstdio>
29#include <cstring>
30#include <iostream>
31
namespace {
// SendStackWithPayload() flushes the stack writer whenever the running record
// count is a multiple of this value.
constexpr int FLUSH_FLAG = 20;
// Running count of records accepted by SendStackWithPayload().
std::atomic<uint64_t> g_flushCount {0};
// Set to true (once) by DisableHook(); SendStackWithPayload() refuses to send while it is set.
std::atomic<bool> g_disableHook {false};
// Capacity, in bytes, of the buffer that collects the allocator statistics text.
constexpr uint32_t MEMCHECK_DETAILINFO_MAXSIZE = 102400;

// Cursor handed to NmdWriteStat() through the opaque callback argument.
struct OptArg {
    size_t pos;  // number of bytes already written into buf
    char *buf;   // destination buffer of MEMCHECK_DETAILINFO_MAXSIZE bytes
};

} // namespace
44
45
// Returns the current local time formatted as "YYYY/MM/DD HH:MM:SS".
// If the time cannot be converted or formatted, the literal
// "error time format!" is returned instead.
static std::string GetRealTime()
{
    const int stampCapacity = 64;
    char formatted[stampCapacity] = {0};
    struct tm localParts;
    time_t current = time(nullptr);

    if (localtime_r(&current, &localParts) == nullptr) {
        return "error time format!";
    }
    if (strftime(formatted, stampCapacity, "%Y/%m/%d %H:%M:%S", &localParts) == 0) {
        return "error time format!";
    }
    return std::string(formatted);
}
58
59static void NmdWriteStat(void *arg, const char *buf)
60{
61    struct OptArg *opt = static_cast<struct OptArg*>(arg);
62    std::string getNmdTime = std::to_string(getpid()) + " " + GetRealTime() + "\n";
63    size_t nmdTimeLen = getNmdTime.size();
64    if (strncpy_s(opt->buf + opt->pos, MEMCHECK_DETAILINFO_MAXSIZE - opt->pos,
65                  getNmdTime.c_str(), nmdTimeLen) != EOK) {
66        return;
67    }
68    opt->pos += nmdTimeLen;
69
70    size_t len = strlen(buf);
71    if (len + opt->pos + 1 > MEMCHECK_DETAILINFO_MAXSIZE) {
72        return;
73    }
74    if (strncpy_s(opt->buf + opt->pos, MEMCHECK_DETAILINFO_MAXSIZE - opt->pos, buf, len) != EOK) {
75        return;
76    }
77    opt->pos += len;
78}
79
80HookSocketClient::HookSocketClient(int pid, ClientConfig *config, Sampling *sampler, void (*disableHookCallback)())
81    : pid_(pid), config_(config), sampler_(sampler), disableHookCallback_(disableHookCallback)
82{
83    smbFd_ = 0;
84    eventFd_ = 0;
85    unixSocketClient_ = nullptr;
86    serviceName_ = "HookService";
87    Connect(DEFAULT_UNIX_SOCKET_HOOK_FULL_PATH);
88}
89
90HookSocketClient::~HookSocketClient()
91{
92    if (stackWriter_) {
93        stackWriter_->Flush();
94    }
95    unixSocketClient_ = nullptr;
96    stackWriter_ = nullptr;
97}
98
99bool HookSocketClient::Connect(const std::string addrname)
100{
101    if (unixSocketClient_ != nullptr) {
102        return false;
103    }
104    unixSocketClient_ = std::make_shared<UnixSocketClient>();
105    if (!unixSocketClient_->Connect(addrname, *this)) {
106        unixSocketClient_ = nullptr;
107        return false;
108    }
109
110    unixSocketClient_->SendHookConfig(reinterpret_cast<uint8_t *>(&pid_), sizeof(pid_));
111    return true;
112}
113
114bool HookSocketClient::ProtocolProc(SocketContext &context, uint32_t pnum, const int8_t *buf, const uint32_t size)
115{
116    if (size != sizeof(ClientConfig)) {
117        return true;
118    }
119    *config_ = *reinterpret_cast<ClientConfig *>(const_cast<int8_t*>(buf));
120    config_->maxStackDepth  = config_->maxStackDepth > MAX_UNWIND_DEPTH ? MAX_UNWIND_DEPTH : config_->maxStackDepth;
121    std::string configStr = config_->ToString();
122    sampler_->InitSampling(config_->sampleInterval);
123    smbFd_ = context.ReceiveFileDiscriptor();
124    eventFd_ = context.ReceiveFileDiscriptor();
125    std::string smbName = "hooknativesmb_" + std::to_string(pid_);
126    stackWriter_ = std::make_shared<StackWriter>(smbName, config_->shareMemorySize,
127        smbFd_, eventFd_, config_->isBlocked);
128    struct mallinfo2 mi = mallinfo2();
129    nmdType_ = config_->nmdType;
130    if (nmdType_ == 0) {
131        SendNmdInfo();
132    }
133    return true;
134}
135
136bool HookSocketClient::SendStack(const void* data, size_t size)
137{
138    if (stackWriter_ == nullptr || unixSocketClient_ == nullptr) {
139        return false;
140    }
141
142    if (!unixSocketClient_->SendHeartBeat()) {
143        return false;
144    }
145
146    stackWriter_->WriteTimeout(data, size);
147    stackWriter_->Flush();
148
149    return true;
150}
151
152bool HookSocketClient::SendStackWithPayload(const void* data, size_t size, const void* payload,
153    size_t payloadSize)
154{
155    if (stackWriter_ == nullptr || unixSocketClient_ == nullptr || g_disableHook) {
156        return false;
157    } else if (unixSocketClient_->GetClientState() == CLIENT_STAT_THREAD_EXITED) {
158        DisableHook();
159        return false;
160    }
161
162    bool ret = stackWriter_->WriteWithPayloadTimeout(data, size, payload, payloadSize,
163                                                     std::bind(&HookSocketClient::PeerIsConnected, this));
164    if (!ret && config_->isBlocked) {
165        DisableHook();
166        return false;
167    }
168    ++g_flushCount;
169    if (g_flushCount % FLUSH_FLAG == 0) {
170        stackWriter_->Flush();
171    }
172    return true;
173}
174
175void HookSocketClient::Flush()
176{
177    if (stackWriter_ == nullptr || unixSocketClient_ == nullptr) {
178        return;
179    }
180    stackWriter_->Flush();
181}
182
183void HookSocketClient::DisableHook()
184{
185    bool expected = false;
186    if (g_disableHook.compare_exchange_strong(expected, true, std::memory_order_release, std::memory_order_relaxed)) {
187        HILOG_BASE_INFO(LOG_CORE, "%s", __func__);
188        if (disableHookCallback_) {
189            disableHookCallback_();
190        }
191    }
192}
193
194bool HookSocketClient::PeerIsConnected()
195{
196    return !unixSocketClient_->IsConnected();
197}
198
199bool HookSocketClient::SendNmdInfo()
200{
201    if (!config_->printNmd) {
202        return false;
203    }
204    void* nmdBuf = malloc(MEMCHECK_DETAILINFO_MAXSIZE);
205    if (nmdBuf == nullptr) {
206        return false;
207    }
208    struct OptArg opt = {0, reinterpret_cast<char*>(nmdBuf) };
209    malloc_stats_print(NmdWriteStat, &opt, "a");
210    StackRawData rawdata = {{{{0}}}};
211    rawdata.type = NMD_MSG;
212    if (stackWriter_) {
213        stackWriter_->WriteWithPayloadTimeout(&rawdata, sizeof(BaseStackRawData),
214                                              reinterpret_cast<int8_t*>(opt.buf), strlen(opt.buf) + 1,
215                                              std::bind(&HookSocketClient::PeerIsConnected, this));
216    }
217    free(nmdBuf);
218    return true;
219}
220
221bool HookSocketClient::SendEndMsg()
222{
223    StackRawData rawdata = {{{{0}}}};
224    rawdata.type = END_MSG;
225    if (stackWriter_) {
226        stackWriter_->WriteTimeout(&rawdata, sizeof(BaseStackRawData));
227    }
228    return true;
229}