1/*
2 * Copyright (c) 2021-2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef OHOS_ROSEN_CLIENT_AGENT_MANAGER_H
17#define OHOS_ROSEN_CLIENT_AGENT_MANAGER_H
18
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <set>
#include <utility>
#include "agent_death_recipient.h"
#include "window_manager_hilog.h"
#include "ipc_skeleton.h"
25
26namespace OHOS {
27namespace Rosen {
// Sentinel returned by ClientAgentContainer::GetAgentPid when no pid is known for an agent.
constexpr int32_t INVALID_PID_ID { -1 };
29template <typename T1, typename T2>
30class ClientAgentContainer {
31public:
32    ClientAgentContainer();
33    virtual ~ClientAgentContainer() = default;
34
35    bool RegisterAgent(const sptr<T1>& agent, T2 type);
36    bool UnregisterAgent(const sptr<T1>& agent, T2 type);
37    std::set<sptr<T1>> GetAgentsByType(T2 type);
38    void SetAgentDeathCallback(std::function<void(const sptr<IRemoteObject>&)> callback);
39    int32_t GetAgentPid(const sptr<T1>& agent);
40
41private:
42    void RemoveAgent(const sptr<IRemoteObject>& remoteObject);
43    bool UnregisterAgentLocked(std::set<sptr<T1>>& agents, const sptr<IRemoteObject>& agent);
44
45    static constexpr HiviewDFX::HiLogLabel LABEL = {LOG_CORE, HILOG_DOMAIN_WINDOW, "ClientAgentContainer"};
46
47    struct finder_t {
48        explicit finder_t(sptr<IRemoteObject> remoteObject) : remoteObject_(remoteObject) {}
49
50        bool operator()(sptr<T1> agent)
51        {
52            if (agent == nullptr) {
53                WLOGFE("agent is invalid");
54                return false;
55            }
56            return agent->AsObject() == remoteObject_;
57        }
58
59        sptr<IRemoteObject> remoteObject_;
60    };
61
62    std::recursive_mutex mutex_;
63    std::map<T2, std::set<sptr<T1>>> agentMap_;
64    std::map<sptr<T1>, int32_t> agentPidMap_;
65    sptr<AgentDeathRecipient> deathRecipient_;
66    std::function<void(const sptr<IRemoteObject>&)> deathCallback_;
67};
68
69template<typename T1, typename T2>
70ClientAgentContainer<T1, T2>::ClientAgentContainer() : deathRecipient_(
71    new AgentDeathRecipient([this](const sptr<IRemoteObject>& remoteObject) { this->RemoveAgent(remoteObject); })) {}
72
73template<typename T1, typename T2>
74bool ClientAgentContainer<T1, T2>::RegisterAgent(const sptr<T1>& agent, T2 type)
75{
76    std::lock_guard<std::recursive_mutex> lock(mutex_);
77    if (agent == nullptr) {
78        WLOGFE("agent is invalid");
79        return false;
80    }
81    agentMap_[type].insert(agent);
82    agentPidMap_[agent] = IPCSkeleton::GetCallingPid();
83    if (deathRecipient_ == nullptr || !agent->AsObject()->AddDeathRecipient(deathRecipient_)) {
84        WLOGFI("failed to add death recipient");
85    }
86    return true;
87}
88
89template<typename T1, typename T2>
90bool ClientAgentContainer<T1, T2>::UnregisterAgent(const sptr<T1>& agent, T2 type)
91{
92    std::lock_guard<std::recursive_mutex> lock(mutex_);
93    if (agent == nullptr) {
94        WLOGFE("agent is invalid");
95        return false;
96    }
97    if (agentMap_.count(type) == 0) {
98        WLOGFD("repeat unregister agent");
99        return true;
100    }
101    auto& agents = agentMap_.at(type);
102    UnregisterAgentLocked(agents, agent->AsObject());
103    agent->AsObject()->RemoveDeathRecipient(deathRecipient_);
104    return true;
105}
106
107template<typename T1, typename T2>
108std::set<sptr<T1>> ClientAgentContainer<T1, T2>::GetAgentsByType(T2 type)
109{
110    std::lock_guard<std::recursive_mutex> lock(mutex_);
111    if (agentMap_.count(type) == 0) {
112        WLOGFD("no such type of agent registered! type:%{public}u", type);
113        return std::set<sptr<T1>>();
114    }
115    return agentMap_.at(type);
116}
117
118template<typename T1, typename T2>
119bool ClientAgentContainer<T1, T2>::UnregisterAgentLocked(std::set<sptr<T1>>& agents,
120    const sptr<IRemoteObject>& agent)
121{
122    if (agent == nullptr) {
123        WLOGFE("agent is invalid");
124        return false;
125    }
126    auto iter = std::find_if(agents.begin(), agents.end(), finder_t(agent));
127    if (iter == agents.end()) {
128        WLOGFD("could not find this agent");
129        return false;
130    }
131    auto agentPidIt = agentPidMap_.find(*iter);
132    if (agentPidIt != agentPidMap_.end()) {
133        int32_t agentPid = agentPidMap_[*iter];
134        agentPidMap_.erase(agentPidIt);
135        WLOGFD("agent pid: %{public}d unregistered", agentPid);
136    }
137    agents.erase(iter);
138    WLOGFD("agent unregistered");
139    return true;
140}
141
142template<typename T1, typename T2>
143void ClientAgentContainer<T1, T2>::RemoveAgent(const sptr<IRemoteObject>& remoteObject)
144{
145    WLOGFI("RemoveAgent");
146    if (remoteObject == nullptr) {
147        WLOGFE("remoteObject is invalid");
148        return;
149    }
150    if (deathCallback_ != nullptr) {
151        deathCallback_(remoteObject);
152    }
153    std::lock_guard<std::recursive_mutex> lock(mutex_);
154    static bool isEntryAgain = false;
155    if (isEntryAgain) {
156        WLOGFW("UnregisterAgentLocked entry again");
157    }
158    isEntryAgain = true;
159    for (auto& elem : agentMap_) {
160        if (UnregisterAgentLocked(elem.second, remoteObject)) {
161            break;
162        }
163    }
164    remoteObject->RemoveDeathRecipient(deathRecipient_);
165    isEntryAgain = false;
166}
167
168template<typename T1, typename T2>
169void ClientAgentContainer<T1, T2>::SetAgentDeathCallback(std::function<void(const sptr<IRemoteObject>&)> callback)
170{
171    deathCallback_ = callback;
172}
173
174template<typename T1, typename T2>
175int32_t ClientAgentContainer<T1, T2>::GetAgentPid(const sptr<T1>& agent)
176{
177    std::lock_guard<std::recursive_mutex> lock(mutex_);
178    if (agent == nullptr) {
179        WLOGFE("agent is invalid");
180        return INVALID_PID_ID;
181    }
182    if (agentPidMap_.count(agent) == 0) {
183        WLOGFE("agent pid not found");
184        return INVALID_PID_ID;
185    }
186    return agentPidMap_[agent];
187}
188}
189}
190#endif // OHOS_ROSEN_CLIENT_AGENT_MANAGER_H
191