1/*
2 * Copyright (c) 2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "LocalSocket.h"
17
18#include "PreviewerEngineLog.h"
19
20LocalSocket::LocalSocket() : pipeHandle(nullptr) {}
21
22LocalSocket::~LocalSocket() {}
23
24bool LocalSocket::ConnectToServer(std::string name, LocalSocket::OpenMode openMode, TransMode transMode)
25{
26    std::wstring tempName = std::wstring(name.begin(), name.end());
27
28    DWORD openModeWin = GetWinOpenMode(openMode);
29    pipeHandle = CreateFileW(tempName.c_str(), openModeWin, 0, nullptr, OPEN_EXISTING, 0, NULL);
30    if (pipeHandle == INVALID_HANDLE_VALUE) {
31        ELOG("LocalSocket::ConnectToServer CreateFileW failed: %d", GetLastError());
32        return false;
33    }
34
35    DWORD tranMode = GetWinTransMode(transMode);
36    if (!SetNamedPipeHandleState(pipeHandle, &tranMode, nullptr, nullptr)) {
37        ELOG("LocalSocket::ConnectToServer SetNamedPipeHandleState failed: %d", GetLastError());
38        return false;
39    }
40
41    return true;
42}
43
44std::string LocalSocket::GetTracePipeName(std::string baseName) const
45{
46    return std::string("\\\\.\\pipe\\") + baseName;
47}
48
49std::string LocalSocket::GetCommandPipeName(std::string baseName) const
50{
51    return std::string("\\\\.\\pipe\\") + baseName + "_commandPipe";
52}
53
54std::string LocalSocket::GetImagePipeName(std::string baseName) const
55{
56    return std::string("\\\\.\\pipe\\") + baseName + "_imagePipe";
57}
58
59void LocalSocket::DisconnectFromServer()
60{
61    CloseHandle(pipeHandle);
62}
63
64int64_t LocalSocket::ReadData(char* data, size_t length) const
65{
66    if (length > UINT32_MAX) {
67        ELOG("LocalSocket::ReadData length must < %d", UINT32_MAX);
68        return -1;
69    }
70
71    DWORD readSize = 0;
72    if (!PeekNamedPipe(pipeHandle, nullptr, 0, nullptr, &readSize, nullptr)) {
73        return 0;
74    }
75
76    if (readSize == 0) {
77        return 0;
78    }
79
80    if (!ReadFile(pipeHandle, data, static_cast<DWORD>(length), &readSize, NULL)) {
81        DWORD error = GetLastError();
82        ELOG("LocalSocket::ReadData ReadFile failed: %d", error);
83        return 0 - static_cast<int64_t>(error);
84    }
85    return readSize;
86}
87
88size_t LocalSocket::WriteData(const void* data, size_t length) const
89{
90    if (length > UINT32_MAX) {
91        ELOG("LocalSocket::WriteData length must < %d", UINT32_MAX);
92        return 0;
93    }
94
95    DWORD writeSize = 0;
96    if (!WriteFile(pipeHandle, data, static_cast<DWORD>(length), &writeSize, nullptr)) {
97        DWORD error = GetLastError();
98        ELOG("LocalSocket::WriteData WriteFile failed: %d", error);
99        return 0 - static_cast<size_t>(error);
100    }
101    return writeSize;
102}
103
104const LocalSocket& LocalSocket::operator<<(const std::string data) const
105{
106    WriteData(data.c_str(), data.length() + 1);
107    return *this;
108}
109
110const LocalSocket& LocalSocket::operator>>(std::string& data) const
111{
112    char c = '\255';
113    while (c != '\0' && ReadData(&c, 1) > 0) {
114        data.push_back(c);
115    }
116    return *this;
117}
118
119DWORD LocalSocket::GetWinOpenMode(LocalSocket::OpenMode mode) const
120{
121    DWORD openModeWin = GENERIC_READ;
122    switch (mode) {
123        case READ_ONLY:
124            openModeWin = GENERIC_READ;
125            break;
126        case WRITE_ONLY:
127            openModeWin = GENERIC_WRITE;
128            break;
129        case READ_WRITE:
130            openModeWin = GENERIC_READ | GENERIC_WRITE;
131    }
132    return openModeWin;
133}
134
135DWORD LocalSocket::GetWinTransMode(LocalSocket::TransMode mode) const
136{
137    DWORD transMode = PIPE_READMODE_BYTE;
138    switch (mode) {
139        case TRANS_BYTE:
140            transMode = PIPE_READMODE_BYTE;
141            break;
142        case TRANS_MESSAGE:
143            transMode = PIPE_READMODE_MESSAGE;
144            break;
145    }
146    return transMode;
147}
148