1/*
2 * Copyright (c) 2021-2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef HISTREAMER_FOUNDATION_OSAL_BASE_SYNCHRONIZER_H
17#define HISTREAMER_FOUNDATION_OSAL_BASE_SYNCHRONIZER_H
18
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>

#include "inner_api/common/log.h"
#include "inner_api/osal/task/condition_variable.h"
#include "inner_api/osal/task/mutex.h"
27
28namespace OHOS {
29namespace Media {
30template <typename SyncIdType, typename ResultType = void>
31class Synchronizer {
32public:
33    explicit Synchronizer(std::string name) : name_(std::move(name))
34    {
35    }
36
37    Synchronizer(const Synchronizer<SyncIdType, ResultType>&) = delete;
38
39    Synchronizer<SyncIdType, ResultType>& operator=(const Synchronizer<SyncIdType, ResultType>&) = delete;
40
41    virtual ~Synchronizer() = default;
42
43    void Wait(SyncIdType syncId, const std::function<void()>& asyncOps)
44    {
45        MEDIA_LOG_I("Synchronizer " PUBLIC_LOG_S " Wait for " PUBLIC_LOG_D32,
46                    name_.c_str(), static_cast<int>(syncId));
47        if (asyncOps) {
48            OSAL::ScopedLock lock(mutex_);
49            waitSet_.insert(syncId);
50            asyncOps();
51            cv_.Wait(lock, [this, syncId] { return syncMap_.find(syncId) != syncMap_.end(); });
52            syncMap_.erase(syncId);
53        }
54    }
55
56    bool WaitFor(SyncIdType syncId, const std::function<void()>& asyncOps, int timeoutMs)
57    {
58        MEDIA_LOG_I("Synchronizer " PUBLIC_LOG_S " Wait for " PUBLIC_LOG_D32 ", timeout: " PUBLIC_LOG_D32,
59                    name_.c_str(), static_cast<int>(syncId), timeoutMs);
60        if (!asyncOps) {
61            return false;
62        }
63        OSAL::ScopedLock lock(mutex_);
64        waitSet_.insert(syncId);
65        asyncOps();
66        auto rtv = cv_.WaitFor(lock, timeoutMs, [this, syncId] { return syncMap_.find(syncId) != syncMap_.end(); });
67        if (rtv) {
68            syncMap_.erase(syncId);
69        } else {
70            waitSet_.erase(syncId);
71        }
72        return rtv;
73    }
74
75    void Wait(SyncIdType syncId, const std::function<void()>& asyncOps, ResultType& result)
76    {
77        MEDIA_LOG_I("Synchronizer " PUBLIC_LOG_S " Wait for " PUBLIC_LOG_D32,
78                    name_.c_str(), static_cast<int>(syncId));
79        if (asyncOps) {
80            OSAL::ScopedLock lock(mutex_);
81            waitSet_.insert(syncId);
82            asyncOps();
83            cv_.Wait(lock, [this, syncId] { return syncMap_.find(syncId) != syncMap_.end(); });
84            result = syncMap_[syncId];
85            syncMap_.erase(syncId);
86        }
87    }
88
89    bool WaitFor(SyncIdType syncId, const std::function<bool()>& asyncOps, int timeoutMs, ResultType& result)
90    {
91        MEDIA_LOG_I("Synchronizer " PUBLIC_LOG_S " Wait for " PUBLIC_LOG_D32 ", timeout: " PUBLIC_LOG_D32,
92                    name_.c_str(), static_cast<int>(syncId), timeoutMs);
93        if (!asyncOps) {
94            return false;
95        }
96        OSAL::ScopedLock lock(mutex_);
97        waitSet_.insert(syncId);
98        if (!asyncOps()) {
99            waitSet_.erase(syncId);
100            return false;
101        }
102        auto rtv = cv_.WaitFor(lock, timeoutMs, [this, syncId] { return syncMap_.find(syncId) != syncMap_.end(); });
103        if (rtv) {
104            result = syncMap_[syncId];
105            syncMap_.erase(syncId);
106            MEDIA_LOG_D("Synchronizer " PUBLIC_LOG_S " Wait for " PUBLIC_LOG_D32 " return.", name_.c_str(),
107                        static_cast<int>(syncId));
108        } else {
109            waitSet_.erase(syncId);
110        }
111        return rtv;
112    }
113
114    void Notify(SyncIdType syncId, ResultType result = ResultType())
115    {
116        MEDIA_LOG_I("Synchronizer " PUBLIC_LOG_S " Notify: " PUBLIC_LOG_D32,
117                    name_.c_str(), static_cast<int>(syncId));
118        OSAL::ScopedLock lock(mutex_);
119        if (waitSet_.find(syncId) != waitSet_.end()) {
120            waitSet_.erase(syncId);
121            syncMap_.insert({syncId, result});
122            cv_.NotifyAll();
123        }
124    }
125
126private:
127    static constexpr OHOS::HiviewDFX::HiLogLabel LABEL = { LOG_CORE, LOG_DOMAIN_FOUNDATION, "Synchronizer" };
128    Mutex mutex_;
129    ConditionVariable cv_;
130    std::string name_;
131    std::map<SyncIdType, ResultType> syncMap_;
132    std::set<SyncIdType> waitSet_;
133};
134} // namespace Media
135} // namespace OHOS
136#endif // HISTREAMER_FOUNDATION_OSAL_BASE_SYNCHRONIZER_H
137