1 /*
2  * Copyright (c) 2021-2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include "network/session_pool.h"

#include <utility>

#include "dm_device_info.h"
#include "device/device_manager_agent.h"
#include "dfs_error.h"
#include "ipc/i_daemon.h"
21 
22 namespace OHOS {
23 namespace Storage {
24 namespace DistributedFile {
25 using namespace std;
namespace {
// Threshold used by DeviceDisconnectCountOnly(): a P2P link whose connect
// count is above this value is still in use elsewhere, so only the counter
// is decremented instead of releasing the session.
constexpr int32_t MOUNT_DFS_COUNT_ONE = 1;
} // namespace
29 
OccupySession(int32_t sessionId, uint8_t linkType)30 void SessionPool::OccupySession(int32_t sessionId, uint8_t linkType)
31 {
32     lock_guard lock(sessionPoolLock_);
33     occupySession_[sessionId] = linkType;
34 }
35 
FindSession(int32_t sessionId)36 bool SessionPool::FindSession(int32_t sessionId)
37 {
38     lock_guard lock(sessionPoolLock_);
39     auto linkTypeIter = occupySession_.find(sessionId);
40     if (linkTypeIter != occupySession_.end()) {
41         return true;
42     } else {
43         return false;
44     }
45 }
46 
HoldSession(shared_ptr<BaseSession> session, const std::string backStage)47 void SessionPool::HoldSession(shared_ptr<BaseSession> session, const std::string backStage)
48 {
49     lock_guard lock(sessionPoolLock_);
50     if (DeviceConnectCountOnly(session)) {
51         LOGE("DeviceConnect Count Only");
52         return;
53     }
54     talker_->SinkSessionTokernel(session, backStage);
55     AddSessionToPool(session);
56 }
57 
ReleaseSession(const int32_t fd)58 uint8_t SessionPool::ReleaseSession(const int32_t fd)
59 {
60     uint8_t linkType = 0;
61     std::string cid = "";
62     lock_guard lock(sessionPoolLock_);
63     LOGI("ReleaseSession start, fd=%{public}d", fd);
64     if (fd < 0) {
65         LOGI("NOTIFY OFFLINE, fd=%{public}d, deviceConnectCount clear", fd);
66         if (deviceConnectCount_.empty()) {
67             return FileManagement::ERR_BAD_VALUE;
68         }
69         deviceConnectCount_.clear();
70         return FileManagement::ERR_BAD_VALUE;
71     }
72     for (auto iter = usrSpaceSessionPool_.begin(); iter != usrSpaceSessionPool_.end(); ++iter) {
73         if ((*iter)->GetHandle() == fd) {
74             auto linkTypeIter = occupySession_.find((*iter)->GetSessionId());
75             if (linkTypeIter != occupySession_.end()) {
76                 linkType = linkTypeIter->second;
77                 cid = (*iter)->GetCid();
78             }
79             if (DeviceDisconnectCountOnly(cid, linkType, false)) {
80                 continue;
81             }
82         }
83         if ((*iter)->GetHandle() == fd) {
84             auto linkTypeIter = occupySession_.find((*iter)->GetSessionId());
85             if (linkTypeIter != occupySession_.end()) {
86                 linkType = linkTypeIter->second;
87                 occupySession_.erase(linkTypeIter);
88                 (*iter)->Release();
89                 iter = usrSpaceSessionPool_.erase(iter);
90                 break;
91             }
92         }
93     }
94     return linkType;
95 }
96 
ReleaseSession(const string &cid, const uint8_t linkType)97 void SessionPool::ReleaseSession(const string &cid, const uint8_t linkType)
98 {
99     uint8_t mlinkType = 0;
100     lock_guard lock(sessionPoolLock_);
101     for (auto iter = usrSpaceSessionPool_.begin(); iter != usrSpaceSessionPool_.end(); ++iter) {
102         if ((*iter)->GetCid() == cid) {
103             auto linkTypeIter = occupySession_.find((*iter)->GetSessionId());
104             if (linkTypeIter != occupySession_.end()) {
105                 mlinkType = (linkTypeIter->second == 0) ? linkType : linkTypeIter->second;
106             }
107             if (mlinkType == linkType && DeviceDisconnectCountOnly(cid, linkType, false)) {
108                 continue;
109             }
110             if (mlinkType == linkType) {
111                 (*iter)->Release();
112                 occupySession_.erase(linkTypeIter);
113                 iter = usrSpaceSessionPool_.erase(iter);
114                 mlinkType = 0;
115                 continue;
116             }
117         }
118     }
119 }
120 
ReleaseAllSession()121 void SessionPool::ReleaseAllSession()
122 {
123     lock_guard lock(sessionPoolLock_);
124     for (auto iter = usrSpaceSessionPool_.begin(); iter != usrSpaceSessionPool_.end();) {
125         talker_->SinkOfflineCmdToKernel((*iter)->GetCid());
126         /* device offline, session release by softbus */
127         iter = usrSpaceSessionPool_.erase(iter);
128     }
129 }
130 
AddSessionToPool(shared_ptr<BaseSession> session)131 void SessionPool::AddSessionToPool(shared_ptr<BaseSession> session)
132 {
133     lock_guard lock(sessionPoolLock_);
134     usrSpaceSessionPool_.push_back(session);
135 }
136 
DeviceDisconnectCountOnly(const string &cid, const uint8_t linkType, bool needErase)137 bool SessionPool::DeviceDisconnectCountOnly(const string &cid, const uint8_t linkType, bool needErase)
138 {
139     if (linkType != LINK_TYPE_P2P) {
140         LOGI("DeviceDisconnectCountOnly return, linkType is %{public}d, not LINK_TYPE_P2P,", linkType);
141         return false;
142     }
143     if (cid.empty()) {
144         LOGE("fail to get networkId");
145         return false;
146     }
147     std::string key = cid + "_" + std::to_string(linkType);
148     auto itCount = deviceConnectCount_.find(key);
149     if (itCount == deviceConnectCount_.end()) {
150         LOGI("deviceConnectCount_ can not find %{public}s", Utils::GetAnonyString(key).c_str());
151         return false;
152     }
153     if (needErase) {
154         deviceConnectCount_.erase(itCount);
155         LOGI("[DeviceDisconnectCountOnly]  %{public}s, needErase", Utils::GetAnonyString(key).c_str());
156         return false;
157     }
158     if (itCount->second > MOUNT_DFS_COUNT_ONE) {
159         LOGI("[DeviceDisconnectCountOnly] networkId_linkType %{public}s has already established \
160             more than one link, count %{public}d, decrease count by one now",
161             Utils::GetAnonyString(key).c_str(), itCount->second);
162         deviceConnectCount_[key]--;
163         return true;
164     } else {
165         LOGI("[DeviceDisconnectCountOnly] networkId_linkType %{public}s erase now", Utils::GetAnonyString(key).c_str());
166         deviceConnectCount_.erase(itCount);
167     }
168     return false;
169 }
170 
DeviceConnectCountOnly(std::shared_ptr<BaseSession> session)171 bool SessionPool::DeviceConnectCountOnly(std::shared_ptr<BaseSession> session)
172 {
173     auto cid = session->GetCid();
174     if (cid.empty()) {
175         LOGE("fail to get networkId");
176         return false;
177     }
178     std::string key = "";
179     auto sessionId = session->GetSessionId();
180     auto it = occupySession_.find(sessionId);
181     if (it != occupySession_.end()) {
182         uint8_t linkType = it->second;
183         if (linkType != LINK_TYPE_P2P) {
184             LOGI("DeviceConnectCountOnly return, linkType is %{public}d, not LINK_TYPE_P2P,", linkType);
185             return false;
186         }
187         key = cid + "_" + std::to_string(linkType);
188     } else {
189         LOGE("occupySession find sessionId failed");
190         return false;
191     }
192     auto itCount = deviceConnectCount_.find(key);
193     if (itCount != deviceConnectCount_.end() && itCount->second > 0) {
194         LOGI("[DeviceConnectCountOnly] networkId_linkType %{public}s has already established a link, \
195             count %{public}d, increase count by one now", Utils::GetAnonyString(key).c_str(), itCount->second);
196         deviceConnectCount_[key]++;
197         return true;
198     } else {
199         LOGI("[DeviceConnectCountOnly] networkId_linkType %{public}s increase count by one now",
200             Utils::GetAnonyString(key).c_str());
201         deviceConnectCount_[key]++;
202     }
203     return false;
204 }
205 
206 } // namespace DistributedFile
207 } // namespace Storage
208 } // namespace OHOS
209