1#region Copyright notice and license
2// Protocol Buffers - Google's data interchange format
3// Copyright 2008 Google Inc.  All rights reserved.
4// https://developers.google.com/protocol-buffers/
5//
6// Redistribution and use in source and binary forms, with or without
7// modification, are permitted provided that the following conditions are
8// met:
9//
10//     * Redistributions of source code must retain the above copyright
11// notice, this list of conditions and the following disclaimer.
12//     * Redistributions in binary form must reproduce the above
13// copyright notice, this list of conditions and the following disclaimer
14// in the documentation and/or other materials provided with the
15// distribution.
16//     * Neither the name of Google Inc. nor the names of its
17// contributors may be used to endorse or promote products derived from
18// this software without specific prior written permission.
19//
20// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31#endregion
32
33using Google.Protobuf.Collections;
34using System;
35using System.Collections.Generic;
36using System.Linq;
37using System.Security;
38
39namespace Google.Protobuf
40{
41    /// <summary>
42    /// Methods for managing <see cref="ExtensionSet{TTarget}"/>s with null checking.
43    ///
44    /// Most users will not use this class directly and its API is experimental and subject to change.
45    /// </summary>
46    public static class ExtensionSet
47    {
48        private static bool TryGetValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension, out IExtensionValue value) where TTarget : IExtendableMessage<TTarget>
49        {
50            if (set == null)
51            {
52                value = null;
53                return false;
54            }
55            return set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value);
56        }
57
58        /// <summary>
59        /// Gets the value of the specified extension
60        /// </summary>
61        public static TValue Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
62        {
63            IExtensionValue value;
64            if (TryGetValue(ref set, extension, out value))
65            {
66                return ((ExtensionValue<TValue>)value).GetValue();
67            }
68            else
69            {
70                return extension.DefaultValue;
71            }
72        }
73
74        /// <summary>
75        /// Gets the value of the specified repeated extension or null if it doesn't exist in this set
76        /// </summary>
77        public static RepeatedField<TValue> Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
78        {
79            IExtensionValue value;
80            if (TryGetValue(ref set, extension, out value))
81            {
82                return ((RepeatedExtensionValue<TValue>)value).GetValue();
83            }
84            else
85            {
86                return null;
87            }
88        }
89
90        /// <summary>
91        /// Gets the value of the specified repeated extension, registering it if it doesn't exist
92        /// </summary>
93        public static RepeatedField<TValue> GetOrInitialize<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
94        {
95            IExtensionValue value;
96            if (set == null)
97            {
98                value = extension.CreateValue();
99                set = new ExtensionSet<TTarget>();
100                set.ValuesByNumber.Add(extension.FieldNumber, value);
101            }
102            else
103            {
104                if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value))
105                {
106                    value = extension.CreateValue();
107                    set.ValuesByNumber.Add(extension.FieldNumber, value);
108                }
109            }
110
111            return ((RepeatedExtensionValue<TValue>)value).GetValue();
112        }
113
114        /// <summary>
115        /// Sets the value of the specified extension. This will make a new instance of ExtensionSet if the set is null.
116        /// </summary>
117        public static void Set<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension, TValue value) where TTarget : IExtendableMessage<TTarget>
118        {
119            ProtoPreconditions.CheckNotNullUnconstrained(value, nameof(value));
120
121            IExtensionValue extensionValue;
122            if (set == null)
123            {
124                extensionValue = extension.CreateValue();
125                set = new ExtensionSet<TTarget>();
126                set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
127            }
128            else
129            {
130                if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out extensionValue))
131                {
132                    extensionValue = extension.CreateValue();
133                    set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
134                }
135            }
136
137            ((ExtensionValue<TValue>)extensionValue).SetValue(value);
138        }
139
140        /// <summary>
141        /// Gets whether the value of the specified extension is set
142        /// </summary>
143        public static bool Has<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
144        {
145            IExtensionValue value;
146            return TryGetValue(ref set, extension, out value);
147        }
148
149        /// <summary>
150        /// Clears the value of the specified extension
151        /// </summary>
152        public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
153        {
154            if (set == null)
155            {
156                return;
157            }
158            set.ValuesByNumber.Remove(extension.FieldNumber);
159            if (set.ValuesByNumber.Count == 0)
160            {
161                set = null;
162            }
163        }
164
165        /// <summary>
166        /// Clears the value of the specified extension
167        /// </summary>
168        public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
169        {
170            if (set == null)
171            {
172                return;
173            }
174            set.ValuesByNumber.Remove(extension.FieldNumber);
175            if (set.ValuesByNumber.Count == 0)
176            {
177                set = null;
178            }
179        }
180
181        /// <summary>
182        /// Tries to merge a field from the coded input, returning true if the field was merged.
183        /// If the set is null or the field was not otherwise merged, this returns false.
184        /// </summary>
185        public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, CodedInputStream stream) where TTarget : IExtendableMessage<TTarget>
186        {
187            ParseContext.Initialize(stream, out ParseContext ctx);
188            try
189            {
190                return TryMergeFieldFrom<TTarget>(ref set, ref ctx);
191            }
192            finally
193            {
194                ctx.CopyStateTo(stream);
195            }
196        }
197
198        /// <summary>
199        /// Tries to merge a field from the coded input, returning true if the field was merged.
200        /// If the set is null or the field was not otherwise merged, this returns false.
201        /// </summary>
202        public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, ref ParseContext ctx) where TTarget : IExtendableMessage<TTarget>
203        {
204            Extension extension;
205            int lastFieldNumber = WireFormat.GetTagFieldNumber(ctx.LastTag);
206
207            IExtensionValue extensionValue;
208            if (set != null && set.ValuesByNumber.TryGetValue(lastFieldNumber, out extensionValue))
209            {
210                extensionValue.MergeFrom(ref ctx);
211                return true;
212            }
213            else if (ctx.ExtensionRegistry != null && ctx.ExtensionRegistry.ContainsInputField(ctx.LastTag, typeof(TTarget), out extension))
214            {
215                IExtensionValue value = extension.CreateValue();
216                value.MergeFrom(ref ctx);
217                set = (set ?? new ExtensionSet<TTarget>());
218                set.ValuesByNumber.Add(extension.FieldNumber, value);
219                return true;
220            }
221            else
222            {
223                return false;
224            }
225        }
226
227        /// <summary>
228        /// Merges the second set into the first set, creating a new instance if first is null
229        /// </summary>
230        public static void MergeFrom<TTarget>(ref ExtensionSet<TTarget> first, ExtensionSet<TTarget> second) where TTarget : IExtendableMessage<TTarget>
231        {
232            if (second == null)
233            {
234                return;
235            }
236            if (first == null)
237            {
238                first = new ExtensionSet<TTarget>();
239            }
240            foreach (var pair in second.ValuesByNumber)
241            {
242                IExtensionValue value;
243                if (first.ValuesByNumber.TryGetValue(pair.Key, out value))
244                {
245                    value.MergeFrom(pair.Value);
246                }
247                else
248                {
249                    var cloned = pair.Value.Clone();
250                    first.ValuesByNumber[pair.Key] = cloned;
251                }
252            }
253        }
254
255        /// <summary>
256        /// Clones the set into a new set. If the set is null, this returns null
257        /// </summary>
258        public static ExtensionSet<TTarget> Clone<TTarget>(ExtensionSet<TTarget> set) where TTarget : IExtendableMessage<TTarget>
259        {
260            if (set == null)
261            {
262                return null;
263            }
264
265            var newSet = new ExtensionSet<TTarget>();
266            foreach (var pair in set.ValuesByNumber)
267            {
268                var cloned = pair.Value.Clone();
269                newSet.ValuesByNumber[pair.Key] = cloned;
270            }
271            return newSet;
272        }
273    }
274
275    /// <summary>
276    /// Used for keeping track of extensions in messages.
277    /// <see cref="IExtendableMessage{T}"/> methods route to this set.
278    ///
279    /// Most users will not need to use this class directly
280    /// </summary>
281    /// <typeparam name="TTarget">The message type that extensions in this set target</typeparam>
282    public sealed class ExtensionSet<TTarget> where TTarget : IExtendableMessage<TTarget>
283    {
284        internal Dictionary<int, IExtensionValue> ValuesByNumber { get; } = new Dictionary<int, IExtensionValue>();
285
286        /// <summary>
287        /// Gets a hash code of the set
288        /// </summary>
289        public override int GetHashCode()
290        {
291            int ret = typeof(TTarget).GetHashCode();
292            foreach (KeyValuePair<int, IExtensionValue> field in ValuesByNumber)
293            {
294                // Use ^ here to make the field order irrelevant.
295                int hash = field.Key.GetHashCode() ^ field.Value.GetHashCode();
296                ret ^= hash;
297            }
298            return ret;
299        }
300
301        /// <summary>
302        /// Returns whether this set is equal to the other object
303        /// </summary>
304        public override bool Equals(object other)
305        {
306            if (ReferenceEquals(this, other))
307            {
308                return true;
309            }
310            ExtensionSet<TTarget> otherSet = other as ExtensionSet<TTarget>;
311            if (ValuesByNumber.Count != otherSet.ValuesByNumber.Count)
312            {
313                return false;
314            }
315            foreach (var pair in ValuesByNumber)
316            {
317                IExtensionValue secondValue;
318                if (!otherSet.ValuesByNumber.TryGetValue(pair.Key, out secondValue))
319                {
320                    return false;
321                }
322                if (!pair.Value.Equals(secondValue))
323                {
324                    return false;
325                }
326            }
327            return true;
328        }
329
330        /// <summary>
331        /// Calculates the size of this extension set
332        /// </summary>
333        public int CalculateSize()
334        {
335            int size = 0;
336            foreach (var value in ValuesByNumber.Values)
337            {
338                size += value.CalculateSize();
339            }
340            return size;
341        }
342
343        /// <summary>
344        /// Writes the extension values in this set to the output stream
345        /// </summary>
346        public void WriteTo(CodedOutputStream stream)
347        {
348
349            WriteContext.Initialize(stream, out WriteContext ctx);
350            try
351            {
352                WriteTo(ref ctx);
353            }
354            finally
355            {
356                ctx.CopyStateTo(stream);
357            }
358        }
359
360        /// <summary>
361        /// Writes the extension values in this set to the write context
362        /// </summary>
363        [SecuritySafeCritical]
364        public void WriteTo(ref WriteContext ctx)
365        {
366            foreach (var value in ValuesByNumber.Values)
367            {
368                value.WriteTo(ref ctx);
369            }
370        }
371
372        internal bool IsInitialized()
373        {
374            return ValuesByNumber.Values.All(v => v.IsInitialized());
375        }
376    }
377}
378