1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "remote_auth_service.h"
17 
18 #include "iam_check.h"
19 #include "iam_defines.h"
20 #include "iam_logger.h"
21 #include "iam_para2str.h"
22 #include "iam_ptr.h"
23 
24 #include "context_factory.h"
25 #include "context_helper.h"
26 #include "context_pool.h"
27 #include "device_manager_util.h"
28 #include "hdi_wrapper.h"
29 #include "iam_para2str.h"
30 #include "remote_executor_stub.h"
31 #include "remote_iam_callback.h"
32 #include "remote_msg_util.h"
33 
34 #define LOG_TAG "USER_AUTH_SA"
35 
36 namespace OHOS {
37 namespace UserIam {
38 namespace UserAuth {
39 class RemoteAuthServiceImpl : public RemoteAuthService {
40 public:
41     static RemoteAuthServiceImpl &GetInstance();
42     RemoteAuthServiceImpl() = default;
43     ~RemoteAuthServiceImpl() override = default;
44 
45     bool Start() override;
46     void OnMessage(const std::string &connectionName, const std::string &srcEndPoint,
47         const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply) override;
48 
49     int32_t ProcStartRemoteAuthRequest(const std::string &connectionName, const std::shared_ptr<Attributes> &request,
50         std::shared_ptr<Attributes> &reply) override;
51     int32_t ProcQueryExecutorInfoRequest(const std::shared_ptr<Attributes> &request,
52         std::shared_ptr<Attributes> &reply) override;
53     int32_t ProcBeginExecuteRequest(const std::shared_ptr<Attributes> &request,
54         std::shared_ptr<Attributes> &reply) override;
55     int32_t ProcEndExecuteRequest(const std::shared_ptr<Attributes> &request,
56         std::shared_ptr<Attributes> &reply) override;
57 
58     uint64_t StartRemoteAuthContext(Authentication::AuthenticationPara para,
59         RemoteAuthContextParam remoteAuthContextParam,
60         const std::shared_ptr<ContextCallback> &contextCallback, int &lastError) override;
61 
62 private:
63     std::shared_ptr<ContextCallback> GetRemoteAuthContextCallback(std::string connectionName,
64         Authentication::AuthenticationPara para);
65     std::recursive_mutex mutex_;
66     std::map<uint64_t, std::shared_ptr<RemoteExecutorStub>> scheduleId2executorStub_;
67 };
68 
69 class RemoteAuthServiceImplConnectionListener : public ConnectionListener {
70 public:
71     RemoteAuthServiceImplConnectionListener() = default;
72     ~RemoteAuthServiceImplConnectionListener() override = default;
73 
74     void OnMessage(const std::string &connectionName, const std::string &srcEndPoint,
75         const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply) override
76     {
77         IF_FALSE_LOGE_AND_RETURN(request != nullptr);
78         IF_FALSE_LOGE_AND_RETURN(reply != nullptr);
79 
80         IAM_LOGI("connectionName: %{public}s, srcEndPoint: %{public}s", connectionName.c_str(), srcEndPoint.c_str());
81 
82         RemoteAuthServiceImpl::GetInstance().OnMessage(connectionName, srcEndPoint, request, reply);
83     }
84 
85     void OnConnectStatus(const std::string &connectionName, ConnectStatus connectStatus) override
86     {
87     }
88 };
89 
GetInstance()90 RemoteAuthServiceImpl &RemoteAuthServiceImpl::GetInstance()
91 {
92     static RemoteAuthServiceImpl remoteAuthServiceImpl;
93     return remoteAuthServiceImpl;
94 }
95 
Start()96 bool RemoteAuthServiceImpl::Start()
97 {
98     std::lock_guard<std::recursive_mutex> lock(mutex_);
99     IAM_LOGI("start");
100 
101     static auto callback = Common::MakeShared<RemoteAuthServiceImplConnectionListener>();
102     IF_FALSE_LOGE_AND_RETURN_VAL(callback != nullptr, false);
103     ResultCode registerResult = RemoteConnectionManager::GetInstance().RegisterConnectionListener(
104         REMOTE_SERVICE_ENDPOINT_NAME, callback);
105     IF_FALSE_LOGE_AND_RETURN_VAL(registerResult == SUCCESS, false);
106     IAM_LOGI("success");
107     return true;
108 }
109 
OnMessage(const std::string &connectionName, const std::string &srcEndPoint, const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)110 void RemoteAuthServiceImpl::OnMessage(const std::string &connectionName, const std::string &srcEndPoint,
111     const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)
112 {
113     IAM_LOGI("start");
114     std::lock_guard<std::recursive_mutex> lock(mutex_);
115 
116     IF_FALSE_LOGE_AND_RETURN(request != nullptr);
117     IF_FALSE_LOGE_AND_RETURN(reply != nullptr);
118 
119     int32_t msgType;
120     bool getMsgTypeRet = request->GetInt32Value(Attributes::ATTR_MSG_TYPE, msgType);
121     IF_FALSE_LOGE_AND_RETURN(getMsgTypeRet);
122 
123     IAM_LOGI("msgType is %{public}d", msgType);
124     int32_t resultCode = ResultCode::GENERAL_ERROR;
125     switch (msgType) {
126         case START_REMOTE_AUTH:
127             resultCode = ProcStartRemoteAuthRequest(connectionName, request, reply);
128             break;
129         case QUERY_EXECUTOR_INFO:
130             resultCode = ProcQueryExecutorInfoRequest(request, reply);
131             break;
132         case BEGIN_EXECUTE:
133             resultCode = ProcBeginExecuteRequest(request, reply);
134             break;
135         case END_EXECUTE:
136             resultCode = ProcEndExecuteRequest(request, reply);
137             break;
138         case KEEP_ALIVE:
139             resultCode = SUCCESS;
140             break;
141         default:
142             IAM_LOGE("unsupported request type: %{public}d", msgType);
143             break;
144     }
145 
146     bool setResultCodeRet = reply->SetInt32Value(Attributes::ATTR_RESULT_CODE, resultCode);
147     IF_FALSE_LOGE_AND_RETURN(setResultCodeRet);
148 
149     IAM_LOGI("success, msg result %{public}d", resultCode);
150 }
151 
StartRemoteAuthContext(Authentication::AuthenticationPara para, RemoteAuthContextParam remoteAuthContextParam, const std::shared_ptr<ContextCallback> &contextCallback, int &lastError)152 uint64_t RemoteAuthServiceImpl::StartRemoteAuthContext(Authentication::AuthenticationPara para,
153     RemoteAuthContextParam remoteAuthContextParam, const std::shared_ptr<ContextCallback> &contextCallback,
154     int &lastError)
155 {
156     IAM_LOGI("start");
157     IF_FALSE_LOGE_AND_RETURN_VAL(contextCallback != nullptr, BAD_CONTEXT_ID);
158     Attributes extraInfo;
159     std::shared_ptr<Context> context = ContextFactory::CreateRemoteAuthContext(para, remoteAuthContextParam,
160         contextCallback);
161     if (context == nullptr || !ContextPool::Instance().Insert(context)) {
162         IAM_LOGE("failed to insert context");
163         contextCallback->SetTraceAuthFinishReason("RemoteAuthServiceImpl StartRemoteAuthContext insert context fail");
164         contextCallback->OnResult(GENERAL_ERROR, extraInfo);
165         return BAD_CONTEXT_ID;
166     }
167     contextCallback->SetCleaner(ContextHelper::Cleaner(context));
168     contextCallback->SetTraceRequestContextId(context->GetContextId());
169     contextCallback->SetTraceAuthContextId(context->GetContextId());
170 
171     if (!context->Start()) {
172         lastError = context->GetLatestError();
173         IAM_LOGE("failed to start auth errorCode:%{public}d", lastError);
174         return BAD_CONTEXT_ID;
175     }
176     lastError = SUCCESS;
177     IAM_LOGI("success");
178     return context->GetContextId();
179 }
180 
GetRemoteAuthContextCallback(std::string connectionName, Authentication::AuthenticationPara para)181 std::shared_ptr<ContextCallback> RemoteAuthServiceImpl::GetRemoteAuthContextCallback(std::string connectionName,
182     Authentication::AuthenticationPara para)
183 {
184     sptr<IamCallbackInterface> callback(new RemoteIamCallback(connectionName));
185     IF_FALSE_LOGE_AND_RETURN_VAL(callback != nullptr, nullptr);
186 
187     auto contextCallback = ContextCallback::NewInstance(callback, TRACE_AUTH_USER_ALL);
188     IF_FALSE_LOGE_AND_RETURN_VAL(contextCallback != nullptr, nullptr);
189     contextCallback->SetTraceUserId(para.userId);
190     contextCallback->SetTraceAuthWidgetType(para.authType);
191     contextCallback->SetTraceAuthType(para.authType);
192     contextCallback->SetTraceAuthTrustLevel(para.atl);
193     contextCallback->SetTraceSdkVersion(para.sdkVersion);
194     contextCallback->SetTraceCallerName(para.callerName);
195     contextCallback->SetTraceCallerType(para.callerType);
196     return contextCallback;
197 }
198 
ProcStartRemoteAuthRequest(const std::string &connectionName, const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)199 int32_t RemoteAuthServiceImpl::ProcStartRemoteAuthRequest(const std::string &connectionName,
200     const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)
201 {
202     std::lock_guard<std::recursive_mutex> lock(mutex_);
203     IAM_LOGI("start");
204     AuthParamInner authParam = {};
205     bool getAuthParamRet = RemoteMsgUtil::DecodeAuthParam(*request, authParam);
206     IF_FALSE_LOGE_AND_RETURN_VAL(getAuthParamRet, GENERAL_ERROR);
207 
208     std::string collectorNetworkId;
209     bool getCollectorNetworkIdRet = request->GetStringValue(Attributes::ATTR_COLLECTOR_NETWORK_ID, collectorNetworkId);
210     IF_FALSE_LOGE_AND_RETURN_VAL(getCollectorNetworkIdRet, GENERAL_ERROR);
211 
212     uint32_t collectorTokenId;
213     bool getCollectorTokenIdRet = request->GetUint32Value(Attributes::ATTR_COLLECTOR_TOKEN_ID, collectorTokenId);
214     IF_FALSE_LOGE_AND_RETURN_VAL(getCollectorTokenIdRet, GENERAL_ERROR);
215 
216     Authentication::AuthenticationPara para = {};
217     para.userId = authParam.userId;
218     para.authType = authParam.authType;
219     para.atl = authParam.authTrustLevel;
220     para.collectorTokenId = collectorTokenId;
221     para.challenge = authParam.challenge;
222     para.sdkVersion = INNER_API_VERSION_10000;
223 
224     bool getCallerNameRet = request->GetStringValue(Attributes::ATTR_CALLER_NAME, para.callerName);
225     IF_FALSE_LOGE_AND_RETURN_VAL(getCallerNameRet, GENERAL_ERROR);
226     bool getCallerTypeRet = request->GetInt32Value(Attributes::ATTR_CALLER_TYPE, para.callerType);
227     IF_FALSE_LOGE_AND_RETURN_VAL(getCallerTypeRet, GENERAL_ERROR);
228 
229     RemoteAuthContextParam remoteAuthContextParam;
230     remoteAuthContextParam.authType = authParam.authType;
231     remoteAuthContextParam.connectionName = connectionName;
232     remoteAuthContextParam.collectorNetworkId = collectorNetworkId;
233     remoteAuthContextParam.executorInfoMsg = request->Serialize();
234 
235     auto contextCallback = GetRemoteAuthContextCallback(connectionName, para);
236     IF_FALSE_LOGE_AND_RETURN_VAL(contextCallback != nullptr, GENERAL_ERROR);
237 
238     int32_t lastError;
239     auto contextId = StartRemoteAuthContext(para, remoteAuthContextParam, contextCallback, lastError);
240     IF_FALSE_LOGE_AND_RETURN_VAL(contextId != BAD_CONTEXT_ID, lastError);
241 
242     bool setContextIdRet = reply->SetUint64Value(Attributes::ATTR_CONTEXT_ID, contextId);
243     IF_FALSE_LOGE_AND_RETURN_VAL(setContextIdRet, GENERAL_ERROR);
244 
245     IAM_LOGI("success");
246     return SUCCESS;
247 }
248 
ProcQueryExecutorInfoRequest(const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)249 int32_t RemoteAuthServiceImpl::ProcQueryExecutorInfoRequest(const std::shared_ptr<Attributes> &request,
250     std::shared_ptr<Attributes> &reply)
251 {
252     std::lock_guard<std::recursive_mutex> lock(mutex_);
253     IAM_LOGI("start");
254 
255     std::vector<int32_t> authTypes;
256     bool getAuthTypesRet = request->GetInt32ArrayValue(Attributes::ATTR_AUTH_TYPES, authTypes);
257     IF_FALSE_LOGE_AND_RETURN_VAL(getAuthTypesRet, GENERAL_ERROR);
258 
259     int32_t executorRole;
260     bool getExecutorRoleRet = request->GetInt32Value(Attributes::ATTR_EXECUTOR_ROLE, executorRole);
261     IF_FALSE_LOGE_AND_RETURN_VAL(getExecutorRoleRet, GENERAL_ERROR);
262 
263     std::string srcUdid;
264     bool getSrcUdidRet = request->GetStringValue(Attributes::ATTR_MSG_SRC_UDID, srcUdid);
265     IF_FALSE_LOGE_AND_RETURN_VAL(getSrcUdidRet, GENERAL_ERROR);
266 
267     bool getQueryExecutorInfoRet = RemoteMsgUtil::GetQueryExecutorInfoReply(authTypes, executorRole, srcUdid, *reply);
268     IF_FALSE_LOGE_AND_RETURN_VAL(getQueryExecutorInfoRet, GENERAL_ERROR);
269 
270     IAM_LOGI("success");
271     return SUCCESS;
272 }
273 
ProcBeginExecuteRequest(const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)274 int32_t RemoteAuthServiceImpl::ProcBeginExecuteRequest(const std::shared_ptr<Attributes> &request,
275     std::shared_ptr<Attributes> &reply)
276 {
277     std::lock_guard<std::recursive_mutex> lock(mutex_);
278     IAM_LOGI("start");
279 
280     std::shared_ptr<RemoteExecutorStub> executorStub = Common::MakeShared<RemoteExecutorStub>();
281     IF_FALSE_LOGE_AND_RETURN_VAL(executorStub != nullptr, GENERAL_ERROR);
282 
283     RemoteExecuteTrace traceInfo;
284     traceInfo.operationResult = executorStub->ProcBeginExecuteRequest(*request, traceInfo);
285     ReportRemoteExecuteProc(traceInfo);
286     IF_FALSE_LOGE_AND_RETURN_VAL(traceInfo.operationResult == SUCCESS, GENERAL_ERROR);
287 
288     scheduleId2executorStub_.emplace(traceInfo.scheduleId, executorStub);
289     IAM_LOGI("scheduleId %{public}s begin execute success", GET_MASKED_STRING(traceInfo.scheduleId).c_str());
290     return SUCCESS;
291 }
292 
ProcEndExecuteRequest(const std::shared_ptr<Attributes> &request, std::shared_ptr<Attributes> &reply)293 int32_t RemoteAuthServiceImpl::ProcEndExecuteRequest(const std::shared_ptr<Attributes> &request,
294     std::shared_ptr<Attributes> &reply)
295 {
296     std::lock_guard<std::recursive_mutex> lock(mutex_);
297     IAM_LOGI("start");
298 
299     uint64_t scheduleId;
300     bool getScheduleIdRet = request->GetUint64Value(Attributes::ATTR_SCHEDULE_ID, scheduleId);
301     IF_FALSE_LOGE_AND_RETURN_VAL(getScheduleIdRet, GENERAL_ERROR);
302 
303     auto it = scheduleId2executorStub_.find(scheduleId);
304     IF_FALSE_LOGE_AND_RETURN_VAL(it != scheduleId2executorStub_.end(), GENERAL_ERROR);
305     scheduleId2executorStub_.erase(it);
306     IAM_LOGI("scheduleId %{public}s end execute success", GET_MASKED_STRING(scheduleId).c_str());
307     return SUCCESS;
308 }
309 
GetInstance()310 RemoteAuthService &RemoteAuthService::GetInstance()
311 {
312     RemoteAuthServiceImpl &impl = RemoteAuthServiceImpl::GetInstance();
313     return impl;
314 }
315 } // namespace UserAuth
316 } // namespace UserIam
317 } // namespace OHOS
318