1/*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 *     http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#include "class_registry.h"
17
18#include <base/util/uid_util.h>
19
20#include "meta/base/interface_utils.h"
21
22META_BEGIN_NAMESPACE()
23
24void ClassRegistry::Clear()
25{
26    std::unique_lock lock { mutex_ };
27    objectFactories_.clear();
28}
29
30bool ClassRegistry::Unregister(const IObjectFactory::Ptr& fac)
31{
32    if (!fac) {
33        CORE_LOG_E("ClassRegistry: Cannot unregister a null object factory");
34        return false;
35    }
36    size_t erased = 0;
37    {
38        std::unique_lock lock { mutex_ };
39        erased = objectFactories_.erase(fac->GetClassInfo());
40    }
41    if (erased) {
42        onUnregistered_->Invoke({ fac });
43        return true;
44    }
45    return false;
46}
47
48bool ClassRegistry::Register(const IObjectFactory::Ptr& fac)
49{
50    if (!fac) {
51        CORE_LOG_E("ClassRegistry: Cannot register a null object factory");
52        return false;
53    }
54    {
55        std::unique_lock lock { mutex_ };
56        auto& info = fac->GetClassInfo();
57        auto& i = objectFactories_[info];
58        if (i) {
59            CORE_LOG_W("ClassRegistry: Cannot register a class that was already registered [name=%s, uid=%s]",
60                info.Name().data(), info.Id().ToString().c_str());
61            return false;
62        }
63        i = fac;
64    }
65    onRegistered_->Invoke({ fac });
66    return true;
67}
68
69IObjectFactory::ConstPtr ClassRegistry::GetObjectFactory(const BASE_NS::Uid& uid) const
70{
71    std::shared_lock lock { mutex_ };
72    auto it = objectFactories_.find(uid);
73    return it != objectFactories_.end() ? it->second : nullptr;
74}
75
76BASE_NS::string ClassRegistry::GetClassName(BASE_NS::Uid uid) const
77{
78    std::shared_lock lock { mutex_ };
79    auto it = objectFactories_.find(uid);
80    return it != objectFactories_.end() ? BASE_NS::string(it->second->GetClassInfo().Name())
81                                        : BASE_NS::string("Unknown class id [") + BASE_NS::to_string(uid) + "]";
82}
83
84BASE_NS::vector<IClassInfo::ConstPtr> ClassRegistry::GetAllTypes(
85    ObjectCategoryBits category, bool strict, bool excludeDeprecated) const
86{
87    std::shared_lock lock { mutex_ };
88    BASE_NS::vector<IClassInfo::ConstPtr> infos;
89    for (auto&& v : objectFactories_) {
90        const auto& factory = v.second;
91        if (excludeDeprecated && (factory->GetClassInfo().category & ObjectCategoryBits::DEPRECATED)) {
92            // Omit DEPRECATED classes if excludeDeprecated flag is true
93            continue;
94        }
95        if (CheckCategoryBits(factory->GetClassInfo().category, category, strict)) {
96            infos.emplace_back(factory);
97        }
98    }
99    return infos;
100}
101
102BASE_NS::vector<IClassInfo::ConstPtr> ClassRegistry::GetAllTypes(
103    const BASE_NS::vector<BASE_NS::Uid>& interfaceUids, bool strict, bool excludeDeprecated) const
104{
105    std::shared_lock lock { mutex_ };
106    BASE_NS::vector<IClassInfo::ConstPtr> infos;
107    for (auto&& v : objectFactories_) {
108        const auto& factory = v.second;
109        if (factory->GetClassInfo().category & ObjectCategoryBits::INTERNAL) {
110            // Omit classes with INTERNAL flag from the list
111            continue;
112        }
113        if (excludeDeprecated && (factory->GetClassInfo().category & ObjectCategoryBits::DEPRECATED)) {
114            // Omit DEPRECATED classes if excludeDeprecated flag is true
115            continue;
116        }
117        if (CheckInterfaces(factory->GetClassInterfaces(), interfaceUids, strict)) {
118            infos.push_back(factory);
119        }
120    }
121    return infos;
122}
123
124META_END_NAMESPACE()
125