1ffe3c632Sopenharmony_ci#!/usr/bin/python2.4
2ffe3c632Sopenharmony_ci#
3ffe3c632Sopenharmony_ci# Copyright 2008 Google Inc.
4ffe3c632Sopenharmony_ci#
5ffe3c632Sopenharmony_ci# Licensed under the Apache License, Version 2.0 (the "License");
6ffe3c632Sopenharmony_ci# you may not use this file except in compliance with the License.
7ffe3c632Sopenharmony_ci# You may obtain a copy of the License at
8ffe3c632Sopenharmony_ci#
9ffe3c632Sopenharmony_ci#      http://www.apache.org/licenses/LICENSE-2.0
10ffe3c632Sopenharmony_ci#
11ffe3c632Sopenharmony_ci# Unless required by applicable law or agreed to in writing, software
12ffe3c632Sopenharmony_ci# distributed under the License is distributed on an "AS IS" BASIS,
13ffe3c632Sopenharmony_ci# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14ffe3c632Sopenharmony_ci# See the License for the specific language governing permissions and
15ffe3c632Sopenharmony_ci# limitations under the License.
16ffe3c632Sopenharmony_ci
17ffe3c632Sopenharmony_ci# This file is used for testing.  The original is at:
18ffe3c632Sopenharmony_ci#   http://code.google.com/p/pymox/
19ffe3c632Sopenharmony_ci
20ffe3c632Sopenharmony_ci"""Mox, an object-mocking framework for Python.
21ffe3c632Sopenharmony_ci
22ffe3c632Sopenharmony_ciMox works in the record-replay-verify paradigm.  When you first create
23ffe3c632Sopenharmony_cia mock object, it is in record mode.  You then programmatically set
24ffe3c632Sopenharmony_cithe expected behavior of the mock object (what methods are to be
25ffe3c632Sopenharmony_cicalled on it, with what parameters, what they should return, and in
26ffe3c632Sopenharmony_ciwhat order).
27ffe3c632Sopenharmony_ci
28ffe3c632Sopenharmony_ciOnce you have set up the expected mock behavior, you put it in replay
29ffe3c632Sopenharmony_cimode.  Now the mock responds to method calls just as you told it to.
30ffe3c632Sopenharmony_ciIf an unexpected method (or an expected method with unexpected
31ffe3c632Sopenharmony_ciparameters) is called, then an exception will be raised.
32ffe3c632Sopenharmony_ci
33ffe3c632Sopenharmony_ciOnce you are done interacting with the mock, you need to verify that
34ffe3c632Sopenharmony_ciall the expected interactions occurred.  (Maybe your code exited
35ffe3c632Sopenharmony_ciprematurely without calling some cleanup method!)  The verify phase
36ffe3c632Sopenharmony_ciensures that every expected method was called; otherwise, an exception
37ffe3c632Sopenharmony_ciwill be raised.
38ffe3c632Sopenharmony_ci
39ffe3c632Sopenharmony_ciSuggested usage / workflow:
40ffe3c632Sopenharmony_ci
41ffe3c632Sopenharmony_ci  # Create Mox factory
42ffe3c632Sopenharmony_ci  my_mox = Mox()
43ffe3c632Sopenharmony_ci
44ffe3c632Sopenharmony_ci  # Create a mock data access object
45ffe3c632Sopenharmony_ci  mock_dao = my_mox.CreateMock(DAOClass)
46ffe3c632Sopenharmony_ci
47ffe3c632Sopenharmony_ci  # Set up expected behavior
48ffe3c632Sopenharmony_ci  mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49ffe3c632Sopenharmony_ci  mock_dao.DeletePerson(person)
50ffe3c632Sopenharmony_ci
51ffe3c632Sopenharmony_ci  # Put mocks in replay mode
52ffe3c632Sopenharmony_ci  my_mox.ReplayAll()
53ffe3c632Sopenharmony_ci
54ffe3c632Sopenharmony_ci  # Inject mock object and run test
55ffe3c632Sopenharmony_ci  controller.SetDao(mock_dao)
56ffe3c632Sopenharmony_ci  controller.DeletePersonById('1')
57ffe3c632Sopenharmony_ci
58ffe3c632Sopenharmony_ci  # Verify all methods were called as expected
59ffe3c632Sopenharmony_ci  my_mox.VerifyAll()
60ffe3c632Sopenharmony_ci"""
61ffe3c632Sopenharmony_ci
62ffe3c632Sopenharmony_cifrom collections import deque
63ffe3c632Sopenharmony_ciimport re
64ffe3c632Sopenharmony_ciimport types
65ffe3c632Sopenharmony_ciimport unittest
66ffe3c632Sopenharmony_ci
67ffe3c632Sopenharmony_ciimport stubout
68ffe3c632Sopenharmony_ci
class Error(AssertionError):
  """Root of every exception raised by this module.

  Deriving from AssertionError lets test frameworks report mock failures as
  ordinary assertion failures.
  """
73ffe3c632Sopenharmony_ci
74ffe3c632Sopenharmony_ci
class ExpectedMethodCallsError(Error):
  """Raised by Verify() when some expected method calls were never made."""

  def __init__(self, expected_methods):
    """Init exception.

    Args:
      # expected_methods: A sequence of MockMethod objects that should have been
      #   called.  The sequence itself is kept (not copied) for __str__.
      expected_methods: [MockMethod]

    Raises:
      ValueError: if expected_methods contains no methods.
    """

    # Refuse to build an error that would have nothing to report.
    if not expected_methods:
      raise ValueError("There must be at least one expected method")
    Error.__init__(self)
    self._expected_methods = expected_methods

  def __str__(self):
    """List each outstanding expectation on its own numbered line."""

    numbered = []
    for index, method in enumerate(self._expected_methods):
      numbered.append("%3d.  %s" % (index, method))
    listing = "\n".join(numbered)
    return "Verify: Expected methods never called:\n%s" % (listing,)
100ffe3c632Sopenharmony_ci
101ffe3c632Sopenharmony_ci
class UnexpectedMethodCallError(Error):
  """Raised when a call does not match what the mock expected next.

  Typical causes are calling a method with the wrong parameters, or calling
  methods in a different order than they were recorded.
  """

  def __init__(self, unexpected_method, expected):
    """Init exception.

    Args:
      # unexpected_method: MockMethod that was called but was not at the head of
      #   the expected_method queue.
      # expected: MockMethod or UnorderedGroup the method should have
      #   been in (None when nothing was expected at all).
      unexpected_method: MockMethod
      expected: MockMethod or UnorderedGroup
    """

    Error.__init__(self)
    self._unexpected_method = unexpected_method
    self._expected = expected

  def __str__(self):
    """Describe both the offending call and the expectation it missed."""

    template = "Unexpected method call: %s.  Expecting: %s"
    return template % (self._unexpected_method, self._expected)
128ffe3c632Sopenharmony_ci
129ffe3c632Sopenharmony_ci
class UnknownMethodCallError(Error):
  """Raised when a mock is asked for a method its mocked type lacks."""

  def __init__(self, unknown_method_name):
    """Init exception.

    Args:
      # unknown_method_name: Method call that is not part of the mocked class's
      #   public interface.
      unknown_method_name: str
    """

    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    """Name the attribute that the mocked interface does not provide."""

    template = "Method called is not a member of the object: %s"
    return template % self._unknown_method_name
148ffe3c632Sopenharmony_ci
149ffe3c632Sopenharmony_ci
class Mox(object):
  """Mox: a factory for creating mock objects.

  Every mock created through this factory is remembered, so that ReplayAll,
  VerifyAll and ResetAll can act on all of them at once.
  """

  # Attribute types that StubOutWithMock replaces with a MockObject; every
  # other kind of attribute is replaced with a MockAnything.
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Initialize a new Mox with no mocks and a fresh stub manager."""

    self._mock_objects = []
    self.stubs = stubout.StubOutForTesting()

  def _TrackMock(self, new_mock):
    """Remember new_mock for the *All methods and hand it back unchanged."""

    self._mock_objects.append(new_mock)
    return new_mock

  def CreateMock(self, class_to_mock):
    """Create a new mock object that enforces class_to_mock's interface.

    Args:
      # class_to_mock: the class to be mocked
      class_to_mock: class

    Returns:
      MockObject that can be used as the class_to_mock would be.
    """

    return self._TrackMock(MockObject(class_to_mock))

  def CreateMockAnything(self):
    """Create a mock that will accept any method calls.

    This does not enforce an interface.

    Returns:
      A new MockAnything that this Mox keeps track of.
    """

    return self._TrackMock(MockAnything())

  def ReplayAll(self):
    """Switch every mock created by this factory into replay mode."""

    for tracked in self._mock_objects:
      tracked._Replay()

  def VerifyAll(self):
    """Verify every mock created by this factory, in creation order."""

    for tracked in self._mock_objects:
      tracked._Verify()

  def ResetAll(self):
    """Reset every mock created by this factory.  Stubs are left in place."""

    for tracked in self._mock_objects:
      tracked._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace a method, attribute, etc. with a Mock.

    Classes, modules, instances and other types listed in _USE_MOCK_OBJECT
    become a MockObject; anything else (method, function, etc.) becomes a
    MockAnything.  Passing use_mock_anything=True forces a MockAnything.

    Args:
      obj: A Python object (class, module, instance, callable).
      attr_name: str.  The name of the attribute to replace with a mock.
      use_mock_anything: bool. True if a MockAnything should be used regardless
        of the type of attribute.
    """

    original_attr = getattr(obj, attr_name)
    wants_mock_object = (type(original_attr) in self._USE_MOCK_OBJECT and
                         not use_mock_anything)
    if wants_mock_object:
      replacement = self.CreateMock(original_attr)
    else:
      replacement = self.CreateMockAnything()

    # The stub manager remembers the original so UnsetStubs can restore it.
    self.stubs.Set(obj, attr_name, replacement)

  def UnsetStubs(self):
    """Put every stubbed-out attribute back to its original value."""

    self.stubs.UnsetAll()
234ffe3c632Sopenharmony_ci
def Replay(*args):
  """Switch each of the given mocks into replay mode, in argument order.

  Args:
    # args is any number of mocks to put into replay mode.
  """

  for mock_obj in args:
    mock_obj._Replay()
244ffe3c632Sopenharmony_ci
245ffe3c632Sopenharmony_ci
def Verify(*args):
  """Verify each of the given mocks, in argument order.

  Args:
    # args is any number of mocks to be verified.
  """

  for mock_obj in args:
    mock_obj._Verify()
255ffe3c632Sopenharmony_ci
256ffe3c632Sopenharmony_ci
def Reset(*args):
  """Reset each of the given mocks, in argument order.

  Args:
    # args is any number of mocks to be reset.
  """

  for mock_obj in args:
    mock_obj._Reset()
266ffe3c632Sopenharmony_ci
267ffe3c632Sopenharmony_ci
class MockAnything:
  """A mock that can be used to mock anything.

  Any attribute that normal lookup cannot find is turned into a MockMethod,
  so no interface is enforced.  This is helpful for mocking classes that do
  not provide a public interface.
  """

  def __init__(self):
    """Start in record mode with an empty expectation queue."""
    self._Reset()

  def __getattr__(self, method_name):
    """Intercept method calls on this object.

    Only invoked for attributes that normal lookup cannot find.  The returned
    MockMethod carries this mock's queue and mode, and its __call__ either
    records or replays the call.

    Args:
      # method name: the name of the method being called.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    return self._CreateMockMethod(method_name)

  def _CreateMockMethod(self, method_name):
    """Build a MockMethod bound to this mock's queue and current mode.

    Args:
      # method name: the name of the method being called.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """

    return MockMethod(method_name, self._expected_calls_queue,
                      self._replay_mode)

  def __nonzero__(self):
    """Always report truthy, so the mock can be used in a conditional."""

    return 1

  def __eq__(self, rhs):
    """Mocks are equal when they share mode and expected-call contents."""

    if not isinstance(rhs, MockAnything):
      return False
    return (self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __ne__(self, rhs):
    """Negation of __eq__."""

    return not self == rhs

  def _Replay(self):
    """Switch from recording expectations to replaying them."""

    self._replay_mode = True

  def _Verify(self):
    """Verify that all of the expected calls have been made.

    Raises:
      ExpectedMethodCallsError: if there are still more method calls in the
        expected queue.
    """

    pending = self._expected_calls_queue
    if not pending:
      return

    # A trailing MultipleTimesGroup is never popped off the queue, so a single
    # leftover group is acceptable as long as it reports itself satisfied.
    only_satisfied_group = (len(pending) == 1 and
                            isinstance(pending[0], MultipleTimesGroup) and
                            pending[0].IsSatisfied())
    if not only_satisfied_group:
      raise ExpectedMethodCallsError(pending)

  def _Reset(self):
    """Return to record mode with a brand-new, empty expectation queue."""

    # Expected calls, consumed from the left while replaying.
    self._expected_calls_queue = deque()

    # Recording (False) rather than replaying (True).
    self._replay_mode = False
357ffe3c632Sopenharmony_ci
358ffe3c632Sopenharmony_ci
class MockObject(MockAnything, object):
  """A mock object that simulates the public/protected interface of a class."""

  def __init__(self, class_to_mock):
    """Initialize a mock object.

    This determines the methods and properties of the class and stores them.

    Args:
      # class_to_mock: class to be mocked
      class_to_mock: class
    """

    # This is used to hack around the mixin/inheritance of MockAnything, which
    # is not a proper object (it can be anything. :-)
    # NOTE(review): presumably the plain MockAnything.__init__(self) form is
    # avoided because a Python 2 unbound-method call checks
    # isinstance(self, MockAnything), and the __class__ property below makes
    # this object report the mocked class instead.
    MockAnything.__dict__['__init__'](self)

    # Get a list of all the public and special methods we should mock.
    # dir() includes attributes inherited from base classes (and, for an
    # instance, those of its class), so inherited methods are mockable too.
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for method in dir(class_to_mock):
      if callable(getattr(class_to_mock, method)):
        self._known_methods.add(method)
      else:
        self._known_vars.add(method)

  def __getattr__(self, name):
    """Intercept attribute request on this object.

    If the attribute is a public class variable, it will be returned and not
    recorded as a call.

    If the attribute is not a variable, it is handled like a method
    call. The method name is checked against the set of mockable
    methods, and a new MockMethod is returned that is aware of the
    MockObject's state (record or replay).  The call will be recorded
    or replayed by the MockMethod's __call__.

    Args:
      # name: the name of the attribute being requested.
      name: str

    Returns:
      Either a class variable or a new MockMethod that is aware of the state
      of the mock (record or replay).

    Raises:
      UnknownMethodCallError if the MockObject does not mock the requested
          method.
    """

    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(name)

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Provide custom logic to compare objects."""

    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Provide custom logic for mocking classes that support item assignment.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __setitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not support item assignment.
      UnexpectedMethodCallError if the object does not expect the call to
        __setitem__.
    """

    # Verify the class supports item assignment.  _known_methods is built from
    # dir(), so an inherited __setitem__ counts.  Checking only
    # _class_to_mock.__dict__ missed inherited methods and, for a mocked
    # instance, looked in the instance namespace instead of its class.  A
    # non-callable placeholder such as None is still treated as unsupported.
    if '__setitem__' not in self._known_methods:
      raise TypeError('object does not support item assignment')

    # The MockMethod records the call in record mode and verifies it against
    # the expected queue in replay mode, so both modes share this path.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Provide custom logic for mocking classes that are subscriptable.

    Args:
      key: Key to return the value for.

    Returns:
      Expected return value in replay mode.  A MockMethod object for the
      __getitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not subscriptable.
      UnexpectedMethodCallError if the object does not expect the call to
        __getitem__.
    """

    # Verify the class is subscriptable (inherited __getitem__ included; see
    # __setitem__ for why _known_methods is used instead of __dict__).
    if '__getitem__' not in self._known_methods:
      raise TypeError('unsubscriptable object')

    # Record or replay the call, depending on the mock's mode.
    return self._CreateMockMethod('__getitem__')(key)

  def __call__(self, *params, **named_params):
    """Provide custom logic for mocking classes that are callable.

    Raises:
      TypeError if the mocked class does not define (or inherit) __call__.
      UnexpectedMethodCallError if, in replay mode, the call is not expected.
    """

    # Verify the class we are mocking is callable.  dir() of a class does not
    # list its metaclass's __call__, so plain classes are not treated as
    # callable unless they define or inherit __call__ themselves.
    if '__call__' not in self._known_methods:
      raise TypeError('Not callable')

    # Because the call is happening directly on this object instead of a method,
    # the call on the mock method is made right here
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked (so isinstance checks pass)."""

    return self._class_to_mock
508ffe3c632Sopenharmony_ci
509ffe3c632Sopenharmony_ci
510ffe3c632Sopenharmony_ciclass MockMethod(object):
511ffe3c632Sopenharmony_ci  """Callable mock method.
512ffe3c632Sopenharmony_ci
513ffe3c632Sopenharmony_ci  A MockMethod should act exactly like the method it mocks, accepting parameters
514ffe3c632Sopenharmony_ci  and returning a value, or throwing an exception (as specified).  When this
515ffe3c632Sopenharmony_ci  method is called, it can optionally verify whether the called method (name and
516ffe3c632Sopenharmony_ci  signature) matches the expected method.
517ffe3c632Sopenharmony_ci  """
518ffe3c632Sopenharmony_ci
519ffe3c632Sopenharmony_ci  def __init__(self, method_name, call_queue, replay_mode):
520ffe3c632Sopenharmony_ci    """Construct a new mock method.
521ffe3c632Sopenharmony_ci
522ffe3c632Sopenharmony_ci    Args:
523ffe3c632Sopenharmony_ci      # method_name: the name of the method
524ffe3c632Sopenharmony_ci      # call_queue: deque of calls, verify this call against the head, or add
525ffe3c632Sopenharmony_ci      #     this call to the queue.
526ffe3c632Sopenharmony_ci      # replay_mode: False if we are recording, True if we are verifying calls
527ffe3c632Sopenharmony_ci      #     against the call queue.
528ffe3c632Sopenharmony_ci      method_name: str
529ffe3c632Sopenharmony_ci      call_queue: list or deque
530ffe3c632Sopenharmony_ci      replay_mode: bool
531ffe3c632Sopenharmony_ci    """
532ffe3c632Sopenharmony_ci
533ffe3c632Sopenharmony_ci    self._name = method_name
534ffe3c632Sopenharmony_ci    self._call_queue = call_queue
535ffe3c632Sopenharmony_ci    if not isinstance(call_queue, deque):
536ffe3c632Sopenharmony_ci      self._call_queue = deque(self._call_queue)
537ffe3c632Sopenharmony_ci    self._replay_mode = replay_mode
538ffe3c632Sopenharmony_ci
539ffe3c632Sopenharmony_ci    self._params = None
540ffe3c632Sopenharmony_ci    self._named_params = None
541ffe3c632Sopenharmony_ci    self._return_value = None
542ffe3c632Sopenharmony_ci    self._exception = None
543ffe3c632Sopenharmony_ci    self._side_effects = None
544ffe3c632Sopenharmony_ci
545ffe3c632Sopenharmony_ci  def __call__(self, *params, **named_params):
546ffe3c632Sopenharmony_ci    """Log parameters and return the specified return value.
547ffe3c632Sopenharmony_ci
548ffe3c632Sopenharmony_ci    If the Mock(Anything/Object) associated with this call is in record mode,
549ffe3c632Sopenharmony_ci    this MockMethod will be pushed onto the expected call queue.  If the mock
550ffe3c632Sopenharmony_ci    is in replay mode, this will pop a MockMethod off the top of the queue and
551ffe3c632Sopenharmony_ci    verify this call is equal to the expected call.
552ffe3c632Sopenharmony_ci
553ffe3c632Sopenharmony_ci    Raises:
554ffe3c632Sopenharmony_ci      UnexpectedMethodCall if this call is supposed to match an expected method
555ffe3c632Sopenharmony_ci        call and it does not.
556ffe3c632Sopenharmony_ci    """
557ffe3c632Sopenharmony_ci
558ffe3c632Sopenharmony_ci    self._params = params
559ffe3c632Sopenharmony_ci    self._named_params = named_params
560ffe3c632Sopenharmony_ci
561ffe3c632Sopenharmony_ci    if not self._replay_mode:
562ffe3c632Sopenharmony_ci      self._call_queue.append(self)
563ffe3c632Sopenharmony_ci      return self
564ffe3c632Sopenharmony_ci
565ffe3c632Sopenharmony_ci    expected_method = self._VerifyMethodCall()
566ffe3c632Sopenharmony_ci
567ffe3c632Sopenharmony_ci    if expected_method._side_effects:
568ffe3c632Sopenharmony_ci      expected_method._side_effects(*params, **named_params)
569ffe3c632Sopenharmony_ci
570ffe3c632Sopenharmony_ci    if expected_method._exception:
571ffe3c632Sopenharmony_ci      raise expected_method._exception
572ffe3c632Sopenharmony_ci
573ffe3c632Sopenharmony_ci    return expected_method._return_value
574ffe3c632Sopenharmony_ci
575ffe3c632Sopenharmony_ci  def __getattr__(self, name):
576ffe3c632Sopenharmony_ci    """Raise an AttributeError with a helpful message."""
577ffe3c632Sopenharmony_ci
578ffe3c632Sopenharmony_ci    raise AttributeError('MockMethod has no attribute "%s". '
579ffe3c632Sopenharmony_ci        'Did you remember to put your mocks in replay mode?' % name)
580ffe3c632Sopenharmony_ci
581ffe3c632Sopenharmony_ci  def _PopNextMethod(self):
582ffe3c632Sopenharmony_ci    """Pop the next method from our call queue."""
583ffe3c632Sopenharmony_ci    try:
584ffe3c632Sopenharmony_ci      return self._call_queue.popleft()
585ffe3c632Sopenharmony_ci    except IndexError:
586ffe3c632Sopenharmony_ci      raise UnexpectedMethodCallError(self, None)
587ffe3c632Sopenharmony_ci
588ffe3c632Sopenharmony_ci  def _VerifyMethodCall(self):
589ffe3c632Sopenharmony_ci    """Verify the called method is expected.
590ffe3c632Sopenharmony_ci
591ffe3c632Sopenharmony_ci    This can be an ordered method, or part of an unordered set.
592ffe3c632Sopenharmony_ci
593ffe3c632Sopenharmony_ci    Returns:
594ffe3c632Sopenharmony_ci      The expected mock method.
595ffe3c632Sopenharmony_ci
596ffe3c632Sopenharmony_ci    Raises:
597ffe3c632Sopenharmony_ci      UnexpectedMethodCall if the method called was not expected.
598ffe3c632Sopenharmony_ci    """
599ffe3c632Sopenharmony_ci
600ffe3c632Sopenharmony_ci    expected = self._PopNextMethod()
601ffe3c632Sopenharmony_ci
602ffe3c632Sopenharmony_ci    # Loop here, because we might have a MethodGroup followed by another
603ffe3c632Sopenharmony_ci    # group.
604ffe3c632Sopenharmony_ci    while isinstance(expected, MethodGroup):
605ffe3c632Sopenharmony_ci      expected, method = expected.MethodCalled(self)
606ffe3c632Sopenharmony_ci      if method is not None:
607ffe3c632Sopenharmony_ci        return method
608ffe3c632Sopenharmony_ci
609ffe3c632Sopenharmony_ci    # This is a mock method, so just check equality.
610ffe3c632Sopenharmony_ci    if expected != self:
611ffe3c632Sopenharmony_ci      raise UnexpectedMethodCallError(self, expected)
612ffe3c632Sopenharmony_ci
613ffe3c632Sopenharmony_ci    return expected
614ffe3c632Sopenharmony_ci
615ffe3c632Sopenharmony_ci  def __str__(self):
616ffe3c632Sopenharmony_ci    params = ', '.join(
617ffe3c632Sopenharmony_ci        [repr(p) for p in self._params or []] +
618ffe3c632Sopenharmony_ci        ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
619ffe3c632Sopenharmony_ci    desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
620ffe3c632Sopenharmony_ci    return desc
621ffe3c632Sopenharmony_ci
622ffe3c632Sopenharmony_ci  def __eq__(self, rhs):
623ffe3c632Sopenharmony_ci    """Test whether this MockMethod is equivalent to another MockMethod.
624ffe3c632Sopenharmony_ci
625ffe3c632Sopenharmony_ci    Args:
626ffe3c632Sopenharmony_ci      # rhs: the right hand side of the test
627ffe3c632Sopenharmony_ci      rhs: MockMethod
628ffe3c632Sopenharmony_ci    """
629ffe3c632Sopenharmony_ci
630ffe3c632Sopenharmony_ci    return (isinstance(rhs, MockMethod) and
631ffe3c632Sopenharmony_ci            self._name == rhs._name and
632ffe3c632Sopenharmony_ci            self._params == rhs._params and
633ffe3c632Sopenharmony_ci            self._named_params == rhs._named_params)
634ffe3c632Sopenharmony_ci
635ffe3c632Sopenharmony_ci  def __ne__(self, rhs):
636ffe3c632Sopenharmony_ci    """Test whether this MockMethod is not equivalent to another MockMethod.
637ffe3c632Sopenharmony_ci
638ffe3c632Sopenharmony_ci    Args:
639ffe3c632Sopenharmony_ci      # rhs: the right hand side of the test
640ffe3c632Sopenharmony_ci      rhs: MockMethod
641ffe3c632Sopenharmony_ci    """
642ffe3c632Sopenharmony_ci
643ffe3c632Sopenharmony_ci    return not self == rhs
644ffe3c632Sopenharmony_ci
645ffe3c632Sopenharmony_ci  def GetPossibleGroup(self):
646ffe3c632Sopenharmony_ci    """Returns a possible group from the end of the call queue or None if no
647ffe3c632Sopenharmony_ci    other methods are on the stack.
648ffe3c632Sopenharmony_ci    """
649ffe3c632Sopenharmony_ci
650ffe3c632Sopenharmony_ci    # Remove this method from the tail of the queue so we can add it to a group.
651ffe3c632Sopenharmony_ci    this_method = self._call_queue.pop()
652ffe3c632Sopenharmony_ci    assert this_method == self
653ffe3c632Sopenharmony_ci
654ffe3c632Sopenharmony_ci    # Determine if the tail of the queue is a group, or just a regular ordered
655ffe3c632Sopenharmony_ci    # mock method.
656ffe3c632Sopenharmony_ci    group = None
657ffe3c632Sopenharmony_ci    try:
658ffe3c632Sopenharmony_ci      group = self._call_queue[-1]
659ffe3c632Sopenharmony_ci    except IndexError:
660ffe3c632Sopenharmony_ci      pass
661ffe3c632Sopenharmony_ci
662ffe3c632Sopenharmony_ci    return group
663ffe3c632Sopenharmony_ci
664ffe3c632Sopenharmony_ci  def _CheckAndCreateNewGroup(self, group_name, group_class):
665ffe3c632Sopenharmony_ci    """Checks if the last method (a possible group) is an instance of our
666ffe3c632Sopenharmony_ci    group_class. Adds the current method to this group or creates a new one.
667ffe3c632Sopenharmony_ci
668ffe3c632Sopenharmony_ci    Args:
669ffe3c632Sopenharmony_ci
670ffe3c632Sopenharmony_ci      group_name: the name of the group.
671ffe3c632Sopenharmony_ci      group_class: the class used to create instance of this new group
672ffe3c632Sopenharmony_ci    """
673ffe3c632Sopenharmony_ci    group = self.GetPossibleGroup()
674ffe3c632Sopenharmony_ci
675ffe3c632Sopenharmony_ci    # If this is a group, and it is the correct group, add the method.
676ffe3c632Sopenharmony_ci    if isinstance(group, group_class) and group.group_name() == group_name:
677ffe3c632Sopenharmony_ci      group.AddMethod(self)
678ffe3c632Sopenharmony_ci      return self
679ffe3c632Sopenharmony_ci
680ffe3c632Sopenharmony_ci    # Create a new group and add the method.
681ffe3c632Sopenharmony_ci    new_group = group_class(group_name)
682ffe3c632Sopenharmony_ci    new_group.AddMethod(self)
683ffe3c632Sopenharmony_ci    self._call_queue.append(new_group)
684ffe3c632Sopenharmony_ci    return self
685ffe3c632Sopenharmony_ci
686ffe3c632Sopenharmony_ci  def InAnyOrder(self, group_name="default"):
687ffe3c632Sopenharmony_ci    """Move this method into a group of unordered calls.
688ffe3c632Sopenharmony_ci
689ffe3c632Sopenharmony_ci    A group of unordered calls must be defined together, and must be executed
690ffe3c632Sopenharmony_ci    in full before the next expected method can be called.  There can be
691ffe3c632Sopenharmony_ci    multiple groups that are expected serially, if they are given
692ffe3c632Sopenharmony_ci    different group names.  The same group name can be reused if there is a
693ffe3c632Sopenharmony_ci    standard method call, or a group with a different name, spliced between
694ffe3c632Sopenharmony_ci    usages.
695ffe3c632Sopenharmony_ci
696ffe3c632Sopenharmony_ci    Args:
697ffe3c632Sopenharmony_ci      group_name: the name of the unordered group.
698ffe3c632Sopenharmony_ci
699ffe3c632Sopenharmony_ci    Returns:
700ffe3c632Sopenharmony_ci      self
701ffe3c632Sopenharmony_ci    """
702ffe3c632Sopenharmony_ci    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
703ffe3c632Sopenharmony_ci
704ffe3c632Sopenharmony_ci  def MultipleTimes(self, group_name="default"):
705ffe3c632Sopenharmony_ci    """Move this method into group of calls which may be called multiple times.
706ffe3c632Sopenharmony_ci
707ffe3c632Sopenharmony_ci    A group of repeating calls must be defined together, and must be executed in
708ffe3c632Sopenharmony_ci    full before the next expected method can be called.
709ffe3c632Sopenharmony_ci
710ffe3c632Sopenharmony_ci    Args:
711ffe3c632Sopenharmony_ci      group_name: the name of the unordered group.
712ffe3c632Sopenharmony_ci
713ffe3c632Sopenharmony_ci    Returns:
714ffe3c632Sopenharmony_ci      self
715ffe3c632Sopenharmony_ci    """
716ffe3c632Sopenharmony_ci    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
717ffe3c632Sopenharmony_ci
718ffe3c632Sopenharmony_ci  def AndReturn(self, return_value):
719ffe3c632Sopenharmony_ci    """Set the value to return when this method is called.
720ffe3c632Sopenharmony_ci
721ffe3c632Sopenharmony_ci    Args:
722ffe3c632Sopenharmony_ci      # return_value can be anything.
723ffe3c632Sopenharmony_ci    """
724ffe3c632Sopenharmony_ci
725ffe3c632Sopenharmony_ci    self._return_value = return_value
726ffe3c632Sopenharmony_ci    return return_value
727ffe3c632Sopenharmony_ci
728ffe3c632Sopenharmony_ci  def AndRaise(self, exception):
729ffe3c632Sopenharmony_ci    """Set the exception to raise when this method is called.
730ffe3c632Sopenharmony_ci
731ffe3c632Sopenharmony_ci    Args:
732ffe3c632Sopenharmony_ci      # exception: the exception to raise when this method is called.
733ffe3c632Sopenharmony_ci      exception: Exception
734ffe3c632Sopenharmony_ci    """
735ffe3c632Sopenharmony_ci
736ffe3c632Sopenharmony_ci    self._exception = exception
737ffe3c632Sopenharmony_ci
738ffe3c632Sopenharmony_ci  def WithSideEffects(self, side_effects):
739ffe3c632Sopenharmony_ci    """Set the side effects that are simulated when this method is called.
740ffe3c632Sopenharmony_ci
741ffe3c632Sopenharmony_ci    Args:
742ffe3c632Sopenharmony_ci      side_effects: A callable which modifies the parameters or other relevant
743ffe3c632Sopenharmony_ci        state which a given test case depends on.
744ffe3c632Sopenharmony_ci
745ffe3c632Sopenharmony_ci    Returns:
746ffe3c632Sopenharmony_ci      Self for chaining with AndReturn and AndRaise.
747ffe3c632Sopenharmony_ci    """
748ffe3c632Sopenharmony_ci    self._side_effects = side_effects
749ffe3c632Sopenharmony_ci    return self
750ffe3c632Sopenharmony_ci
class Comparator:
  """Base class for all Mox comparators.

  Pass a Comparator as an argument to a recorded call when the exact value
  is not known in advance.  For example, the code under test might build a
  long SQL string and pass it to your mock DAO.  If all you care about is
  that the IN clause holds the right primary keys, record:

  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Any query containing the string 'IN (1, 2, 4, 5)' will then match.

  A Comparator may stand in for one or more parameters, for example:
  # return at most 10 rows
  mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

  # Return some non-deterministic number of rows
  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

  Subclasses implement equals(); == and != are routed through it.
  """

  def equals(self, rhs):
    """Decide whether rhs satisfies this comparator; subclasses override it.

    Args:
      rhs: any python object
    """

    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    """Delegate == to equals()."""
    return self.equals(rhs)

  def __ne__(self, rhs):
    """Delegate != to the negation of equals()."""
    matched = self.equals(rhs)
    return not matched
788ffe3c632Sopenharmony_ci
789ffe3c632Sopenharmony_ci
class IsA(Comparator):
  """Comparator matching any parameter that is an instance of a given type.

  Example:
  mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Initialize IsA

    Args:
      class_name: basic python type or a class
    """
    self._class_name = class_name

  def equals(self, rhs):
    """Return whether rhs is an instance of class_name.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool
    """
    expected_type = self._class_name
    try:
      outcome = isinstance(rhs, expected_type)
    except TypeError:
      # isinstance() rejected class_name as its second argument.  Fall back to
      # comparing the raw types of the two objects; the original author notes
      # this helps with things like cStringIO.StringIO.
      outcome = type(rhs) == type(expected_type)
    return outcome

  def __repr__(self):
    return '%s' % (self._class_name,)
827ffe3c632Sopenharmony_ci
class IsAlmost(Comparator):
  """Comparator for numeric parameters that need only be approximately equal.

  Mostly useful for floating point values.

  Example mock_dao.SetTimeout((IsAlmost(3.9)))
  """

  def __init__(self, float_value, places=7):
    """Initialize IsAlmost.

    Args:
      float_value: The value for making the comparison.
      places: The number of decimal places to round to.
    """
    self._float_value = float_value
    self._places = places

  def equals(self, rhs):
    """Return True when rhs - float_value rounds to 0 at `places` decimals.

    Args:
      rhs: the value to compare to float_value

    Returns:
      bool; False when the arithmetic raises TypeError (for example, when
      rhs is not a number).
    """
    try:
      delta = rhs - self._float_value
      return round(delta, self._places) == 0
    except TypeError:
      return False

  def __repr__(self):
    return '%s' % (self._float_value,)
864ffe3c632Sopenharmony_ci
class StrContains(Comparator):
  """Comparator that checks a string parameter for a given substring.

  Handy when mocking, say, a database layer that receives SQL text.

  Example:
  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Initialize.

    Args:
      # search_string: the string you are searching for
      search_string: str
    """
    self._search_string = search_string

  def equals(self, rhs):
    """Return whether rhs.find(search_string) locates the substring.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool; any exception from rhs.find (e.g. rhs has no find method) yields
      False.
    """
    try:
      position = rhs.find(self._search_string)
      return position > -1
    except Exception:
      return False

  def __repr__(self):
    return "<str containing '%s'>" % self._search_string
902ffe3c632Sopenharmony_ci
903ffe3c632Sopenharmony_ci
class Regex(Comparator):
  """Checks if a string matches a regular expression.

  A parameter matches when re.search finds the pattern anywhere in it.
  Values that the compiled pattern cannot search (None, numbers, bytes
  against a str pattern or vice versa) are treated as non-matching instead of
  raising.
  """

  def __init__(self, pattern, flags=0):
    """Initialize.

    Args:
      # pattern is the regular expression to search for
      pattern: str
      # flags passed to re.compile function as the second argument
      flags: int
    """

    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Check to see if rhs matches regular expression pattern.

    Args:
      # rhs: the parameter passed to the mocked method
      rhs: object

    Returns:
      bool: True if the pattern is found in rhs; False otherwise, including
      when rhs is not something this pattern can search.
    """

    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # re raises TypeError for non-string arguments and for str/bytes
      # mismatches.  Report "no match", as StrContains and IsAlmost do, so the
      # parameter shows up as a mismatched call instead of a TypeError
      # escaping from the mock's equality check.
      return False

  def __repr__(self):
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
937ffe3c632Sopenharmony_ci
938ffe3c632Sopenharmony_ci
class In(Comparator):
  """Checks whether an item (or key) is in a list (or dict) parameter.

  Example:
  mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Initialize.

    Args:
      # key is any thing that could be in a list or a key in a dict
    """

    self._key = key

  def equals(self, rhs):
    """Check to see whether key is in rhs.

    Args:
      rhs: dict (or any container supporting the `in` operator)

    Returns:
      bool: True if key is in rhs; False otherwise, including when rhs
      cannot be tested for membership of key.
    """

    try:
      return self._key in rhs
    except TypeError:
      # Examples: rhs is an int or None; key is unhashable and rhs is a dict
      # or set; key is not a str but rhs is a str.  None of these can contain
      # key, so report a mismatch instead of letting TypeError escape.
      return False

  def __repr__(self):
    # Wrap the key in a 1-tuple so a tuple key is formatted whole.  Otherwise
    # % treats it as the argument list: TypeError for lengths other than 1,
    # and only the inner item printed for a 1-tuple.
    return '<sequence or map containing \'%s\'>' % (self._key,)
969ffe3c632Sopenharmony_ci
970ffe3c632Sopenharmony_ci
class ContainsKeyValue(Comparator):
  """Comparator requiring a dict parameter to map a key to a given value.

  Example:
  mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Initialize.

    Args:
      # key: a key in a dict
      # value: the corresponding value
    """
    self._key = key
    self._value = value

  def equals(self, rhs):
    """Return whether rhs[key] == value.

    Returns:
      bool; any exception raised by the lookup or the comparison (e.g. a
      missing key) yields False.
    """
    try:
      found = rhs[self._key]
      return found == self._value
    except Exception:
      return False

  def __repr__(self):
    return "<map containing the entry '%s: %s'>" % (self._key, self._value)
1003ffe3c632Sopenharmony_ci
1004ffe3c632Sopenharmony_ci
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
  mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence (or other iterable) of the expected elements
    """

    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    If every element is hashable, the comparison is set-like: order and
    duplicate counts are both ignored.  Otherwise the elements are compared
    as multisets: order is ignored, but each element must occur the same
    number of times.

    Args:
      actual_seq: sequence

    Returns:
      bool; False if actual_seq (or expected_seq) is not iterable.
    """

    # Materialize both sides once, so a one-shot iterator is not partially
    # consumed by the hashable pass before the fallback reads it.
    try:
      expected = list(self._expected_seq)
      actual = list(actual_seq)
    except TypeError:
      # Not iterable, so it cannot hold the expected elements.
      return False

    try:
      expected_keys = dict([(element, None) for element in expected])
      actual_keys = dict([(element, None) for element in actual])
    except TypeError:
      # At least one element is unhashable.
      return self._UnorderedListsEqual(expected, actual)
    return expected_keys == actual_keys

  def _UnorderedListsEqual(self, expected, actual):
    """Compare two private list copies as multisets (unhashable elements)."""
    # Sorting is cheapest.  Python 3 cannot order many unhashable values
    # (e.g. dicts) or mixed types, so in that case pair the elements off.
    try:
      expected.sort()
      actual.sort()
    except TypeError:
      return self._PairwiseEqual(expected, actual)
    return expected == actual

  def _PairwiseEqual(self, expected, actual):
    """Return True if each expected element matches a distinct actual one.

    Runs in O(len(expected) * len(actual)) time and is only used when sorting
    is impossible.  Expected elements stay on the left of ==, as they are in
    the sorted list comparison.
    """
    if len(expected) != len(actual):
      return False
    remaining = list(actual)
    for wanted in expected:
      for index, candidate in enumerate(remaining):
        if wanted == candidate:
          del remaining[index]
          break
      else:
        return False
    return True

  def __repr__(self):
    # Wrap in a 1-tuple so a tuple sequence is formatted whole instead of
    # being treated as the %-format argument list.
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
1044ffe3c632Sopenharmony_ci
1045ffe3c632Sopenharmony_ci
class And(Comparator):
  """Comparator that matches only if every wrapped Comparator matches.

  The wrapped comparators are checked in order, and checking stops at the
  first one that does not match.
  """

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Comparator
    """
    self._comparators = args

  def equals(self, rhs):
    """Return True unless some wrapped comparator rejects rhs.

    Args:
      # rhs: can be anything

    Returns:
      bool (True when no comparators were given)
    """
    for comparator in self._comparators:
      if comparator.equals(rhs):
        continue
      return False
    return True

  def __repr__(self):
    return '<AND %s>' % (self._comparators,)
1077ffe3c632Sopenharmony_ci
1078ffe3c632Sopenharmony_ci
class Or(Comparator):
  """Comparator that matches if any wrapped Comparator matches.

  The wrapped comparators are checked in order, and checking stops at the
  first one that matches.
  """

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Mox comparators
    """
    self._comparators = args

  def equals(self, rhs):
    """Return True as soon as some wrapped comparator accepts rhs.

    Args:
      # rhs: can be anything

    Returns:
      bool (False when no comparators were given)
    """
    for comparator in self._comparators:
      if not comparator.equals(rhs):
        continue
      return True
    return False

  def __repr__(self):
    return '<OR %s>' % (self._comparators,)
1110ffe3c632Sopenharmony_ci
1111ffe3c632Sopenharmony_ci
class Func(Comparator):
  """Comparator that hands the parameter to a user-supplied validator.

  Use this when checking a parameter needs more logic than the other
  comparators provide.  The callable receives the actual parameter and should
  return True (match) or False (mismatch).

  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Initialize.

    Args:
      func: callable that takes one parameter and returns a bool
    """
    self._func = func

  def equals(self, rhs):
    """Return func(rhs) as-is; the result is not coerced to bool.

    Args:
      rhs: any python object

    Returns:
      the result of func(rhs)
    """
    validator = self._func
    verdict = validator(rhs)
    return verdict

  def __repr__(self):
    return '%s' % (self._func,)
1154ffe3c632Sopenharmony_ci
1155ffe3c632Sopenharmony_ci
class IgnoreArg(Comparator):
  """Comparator that accepts any value in its argument position.

  Useful when a call's argument is irrelevant to the test.

  Example:
  # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
  mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept unconditionally.

    Args:
      unused_rhs: any python object (never inspected)

    Returns:
      True, always.
    """
    del unused_rhs  # Intentionally ignored.
    return True

  def __repr__(self):
    return '<IgnoreArg>'
1180ffe3c632Sopenharmony_ci
1181ffe3c632Sopenharmony_ci
class MethodGroup(object):
  """Shared base for named groups of expected calls.

  Subclasses must override AddMethod, MethodCalled and IsSatisfied; the
  versions here raise NotImplementedError.
  """

  def __init__(self, group_name):
    # Compared by MockMethod when deciding whether a newly recorded call
    # joins this group or starts a new one.
    self._group_name = group_name

  def group_name(self):
    """Return the name this group was created with."""
    return self._group_name

  def __str__(self):
    return '<%s "%s">' % (type(self).__name__, self._group_name)

  def AddMethod(self, mock_method):
    """Subclass hook for adding a recorded method to the group."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Subclass hook for handling a replayed call."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Subclass hook; the base class provides no implementation."""
    raise NotImplementedError
1202ffe3c632Sopenharmony_ci
1203ffe3c632Sopenharmony_ciclass UnorderedGroup(MethodGroup):
1204ffe3c632Sopenharmony_ci  """UnorderedGroup holds a set of method calls that may occur in any order.
1205ffe3c632Sopenharmony_ci
1206ffe3c632Sopenharmony_ci  This construct is helpful for non-deterministic events, such as iterating
1207ffe3c632Sopenharmony_ci  over the keys of a dict.
1208ffe3c632Sopenharmony_ci  """
1209ffe3c632Sopenharmony_ci
1210ffe3c632Sopenharmony_ci  def __init__(self, group_name):
1211ffe3c632Sopenharmony_ci    super(UnorderedGroup, self).__init__(group_name)
1212ffe3c632Sopenharmony_ci    self._methods = []
1213ffe3c632Sopenharmony_ci
1214ffe3c632Sopenharmony_ci  def AddMethod(self, mock_method):
1215ffe3c632Sopenharmony_ci    """Add a method to this group.
1216ffe3c632Sopenharmony_ci
1217ffe3c632Sopenharmony_ci    Args:
1218ffe3c632Sopenharmony_ci      mock_method: A mock method to be added to this group.
1219ffe3c632Sopenharmony_ci    """
1220ffe3c632Sopenharmony_ci
1221ffe3c632Sopenharmony_ci    self._methods.append(mock_method)
1222ffe3c632Sopenharmony_ci
1223ffe3c632Sopenharmony_ci  def MethodCalled(self, mock_method):
1224ffe3c632Sopenharmony_ci    """Remove a method call from the group.
1225ffe3c632Sopenharmony_ci
1226ffe3c632Sopenharmony_ci    If the method is not in the set, an UnexpectedMethodCallError will be
1227ffe3c632Sopenharmony_ci    raised.
1228ffe3c632Sopenharmony_ci
1229ffe3c632Sopenharmony_ci    Args:
1230ffe3c632Sopenharmony_ci      mock_method: a mock method that should be equal to a method in the group.
1231ffe3c632Sopenharmony_ci
1232ffe3c632Sopenharmony_ci    Returns:
1233ffe3c632Sopenharmony_ci      The mock method from the group
1234ffe3c632Sopenharmony_ci
1235ffe3c632Sopenharmony_ci    Raises:
1236ffe3c632Sopenharmony_ci      UnexpectedMethodCallError if the mock_method was not in the group.
1237ffe3c632Sopenharmony_ci    """
1238ffe3c632Sopenharmony_ci
1239ffe3c632Sopenharmony_ci    # Check to see if this method exists, and if so, remove it from the set
1240ffe3c632Sopenharmony_ci    # and return it.
1241ffe3c632Sopenharmony_ci    for method in self._methods:
1242ffe3c632Sopenharmony_ci      if method == mock_method:
1243ffe3c632Sopenharmony_ci        # Remove the called mock_method instead of the method in the group.
1244ffe3c632Sopenharmony_ci        # The called method will match any comparators when equality is checked
1245ffe3c632Sopenharmony_ci        # during removal.  The method in the group could pass a comparator to
1246ffe3c632Sopenharmony_ci        # another comparator during the equality check.
1247ffe3c632Sopenharmony_ci        self._methods.remove(mock_method)
1248ffe3c632Sopenharmony_ci
1249ffe3c632Sopenharmony_ci        # If this group is not empty, put it back at the head of the queue.
1250ffe3c632Sopenharmony_ci        if not self.IsSatisfied():
1251ffe3c632Sopenharmony_ci          mock_method._call_queue.appendleft(self)
1252ffe3c632Sopenharmony_ci
1253ffe3c632Sopenharmony_ci        return self, method
1254ffe3c632Sopenharmony_ci
1255ffe3c632Sopenharmony_ci    raise UnexpectedMethodCallError(mock_method, self)
1256ffe3c632Sopenharmony_ci
1257ffe3c632Sopenharmony_ci  def IsSatisfied(self):
1258ffe3c632Sopenharmony_ci    """Return True if there are not any methods in this group."""
1259ffe3c632Sopenharmony_ci
1260ffe3c632Sopenharmony_ci    return len(self._methods) == 0
1261ffe3c632Sopenharmony_ci
1262ffe3c632Sopenharmony_ci
1263ffe3c632Sopenharmony_ciclass MultipleTimesGroup(MethodGroup):
1264ffe3c632Sopenharmony_ci  """MultipleTimesGroup holds methods that may be called any number of times.
1265ffe3c632Sopenharmony_ci
1266ffe3c632Sopenharmony_ci  Note: Each method must be called at least once.
1267ffe3c632Sopenharmony_ci
1268ffe3c632Sopenharmony_ci  This is helpful, if you don't know or care how many times a method is called.
1269ffe3c632Sopenharmony_ci  """
1270ffe3c632Sopenharmony_ci
1271ffe3c632Sopenharmony_ci  def __init__(self, group_name):
1272ffe3c632Sopenharmony_ci    super(MultipleTimesGroup, self).__init__(group_name)
1273ffe3c632Sopenharmony_ci    self._methods = set()
1274ffe3c632Sopenharmony_ci    self._methods_called = set()
1275ffe3c632Sopenharmony_ci
1276ffe3c632Sopenharmony_ci  def AddMethod(self, mock_method):
1277ffe3c632Sopenharmony_ci    """Add a method to this group.
1278ffe3c632Sopenharmony_ci
1279ffe3c632Sopenharmony_ci    Args:
1280ffe3c632Sopenharmony_ci      mock_method: A mock method to be added to this group.
1281ffe3c632Sopenharmony_ci    """
1282ffe3c632Sopenharmony_ci
1283ffe3c632Sopenharmony_ci    self._methods.add(mock_method)
1284ffe3c632Sopenharmony_ci
1285ffe3c632Sopenharmony_ci  def MethodCalled(self, mock_method):
1286ffe3c632Sopenharmony_ci    """Remove a method call from the group.
1287ffe3c632Sopenharmony_ci
1288ffe3c632Sopenharmony_ci    If the method is not in the set, an UnexpectedMethodCallError will be
1289ffe3c632Sopenharmony_ci    raised.
1290ffe3c632Sopenharmony_ci
1291ffe3c632Sopenharmony_ci    Args:
1292ffe3c632Sopenharmony_ci      mock_method: a mock method that should be equal to a method in the group.
1293ffe3c632Sopenharmony_ci
1294ffe3c632Sopenharmony_ci    Returns:
1295ffe3c632Sopenharmony_ci      The mock method from the group
1296ffe3c632Sopenharmony_ci
1297ffe3c632Sopenharmony_ci    Raises:
1298ffe3c632Sopenharmony_ci      UnexpectedMethodCallError if the mock_method was not in the group.
1299ffe3c632Sopenharmony_ci    """
1300ffe3c632Sopenharmony_ci
1301ffe3c632Sopenharmony_ci    # Check to see if this method exists, and if so add it to the set of
1302ffe3c632Sopenharmony_ci    # called methods.
1303ffe3c632Sopenharmony_ci
1304ffe3c632Sopenharmony_ci    for method in self._methods:
1305ffe3c632Sopenharmony_ci      if method == mock_method:
1306ffe3c632Sopenharmony_ci        self._methods_called.add(mock_method)
1307ffe3c632Sopenharmony_ci        # Always put this group back on top of the queue, because we don't know
1308ffe3c632Sopenharmony_ci        # when we are done.
1309ffe3c632Sopenharmony_ci        mock_method._call_queue.appendleft(self)
1310ffe3c632Sopenharmony_ci        return self, method
1311ffe3c632Sopenharmony_ci
1312ffe3c632Sopenharmony_ci    if self.IsSatisfied():
1313ffe3c632Sopenharmony_ci      next_method = mock_method._PopNextMethod();
1314ffe3c632Sopenharmony_ci      return next_method, None
1315ffe3c632Sopenharmony_ci    else:
1316ffe3c632Sopenharmony_ci      raise UnexpectedMethodCallError(mock_method, self)
1317ffe3c632Sopenharmony_ci
1318ffe3c632Sopenharmony_ci  def IsSatisfied(self):
1319ffe3c632Sopenharmony_ci    """Return True if all methods in this group are called at least once."""
1320ffe3c632Sopenharmony_ci    # NOTE(psycho): We can't use the simple set difference here because we want
1321ffe3c632Sopenharmony_ci    # to match different parameters which are considered the same e.g. IsA(str)
1322ffe3c632Sopenharmony_ci    # and some string. This solution is O(n^2) but n should be small.
1323ffe3c632Sopenharmony_ci    tmp = self._methods.copy()
1324ffe3c632Sopenharmony_ci    for called in self._methods_called:
1325ffe3c632Sopenharmony_ci      for expected in tmp:
1326ffe3c632Sopenharmony_ci        if called == expected:
1327ffe3c632Sopenharmony_ci          tmp.remove(expected)
1328ffe3c632Sopenharmony_ci          if not tmp:
1329ffe3c632Sopenharmony_ci            return True
1330ffe3c632Sopenharmony_ci          break
1331ffe3c632Sopenharmony_ci    return False
1332ffe3c632Sopenharmony_ci
1333ffe3c632Sopenharmony_ci
1334ffe3c632Sopenharmony_ciclass MoxMetaTestBase(type):
1335ffe3c632Sopenharmony_ci  """Metaclass to add mox cleanup and verification to every test.
1336ffe3c632Sopenharmony_ci
1337ffe3c632Sopenharmony_ci  As the mox unit testing class is being constructed (MoxTestBase or a
1338ffe3c632Sopenharmony_ci  subclass), this metaclass will modify all test functions to call the
1339ffe3c632Sopenharmony_ci  CleanUpMox method of the test class after they finish. This means that
1340ffe3c632Sopenharmony_ci  unstubbing and verifying will happen for every test with no additional code,
1341ffe3c632Sopenharmony_ci  and any failures will result in test failures as opposed to errors.
1342ffe3c632Sopenharmony_ci  """
1343ffe3c632Sopenharmony_ci
1344ffe3c632Sopenharmony_ci  def __init__(cls, name, bases, d):
1345ffe3c632Sopenharmony_ci    type.__init__(cls, name, bases, d)
1346ffe3c632Sopenharmony_ci
1347ffe3c632Sopenharmony_ci    # also get all the attributes from the base classes to account
1348ffe3c632Sopenharmony_ci    # for a case when test class is not the immediate child of MoxTestBase
1349ffe3c632Sopenharmony_ci    for base in bases:
1350ffe3c632Sopenharmony_ci      for attr_name in dir(base):
1351ffe3c632Sopenharmony_ci        d[attr_name] = getattr(base, attr_name)
1352ffe3c632Sopenharmony_ci
1353ffe3c632Sopenharmony_ci    for func_name, func in d.items():
1354ffe3c632Sopenharmony_ci      if func_name.startswith('test') and callable(func):
1355ffe3c632Sopenharmony_ci        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
1356ffe3c632Sopenharmony_ci
1357ffe3c632Sopenharmony_ci  @staticmethod
1358ffe3c632Sopenharmony_ci  def CleanUpTest(cls, func):
1359ffe3c632Sopenharmony_ci    """Adds Mox cleanup code to any MoxTestBase method.
1360ffe3c632Sopenharmony_ci
1361ffe3c632Sopenharmony_ci    Always unsets stubs after a test. Will verify all mocks for tests that
1362ffe3c632Sopenharmony_ci    otherwise pass.
1363ffe3c632Sopenharmony_ci
1364ffe3c632Sopenharmony_ci    Args:
1365ffe3c632Sopenharmony_ci      cls: MoxTestBase or subclass; the class whose test method we are altering.
1366ffe3c632Sopenharmony_ci      func: method; the method of the MoxTestBase test class we wish to alter.
1367ffe3c632Sopenharmony_ci
1368ffe3c632Sopenharmony_ci    Returns:
1369ffe3c632Sopenharmony_ci      The modified method.
1370ffe3c632Sopenharmony_ci    """
1371ffe3c632Sopenharmony_ci    def new_method(self, *args, **kwargs):
1372ffe3c632Sopenharmony_ci      mox_obj = getattr(self, 'mox', None)
1373ffe3c632Sopenharmony_ci      cleanup_mox = False
1374ffe3c632Sopenharmony_ci      if mox_obj and isinstance(mox_obj, Mox):
1375ffe3c632Sopenharmony_ci        cleanup_mox = True
1376ffe3c632Sopenharmony_ci      try:
1377ffe3c632Sopenharmony_ci        func(self, *args, **kwargs)
1378ffe3c632Sopenharmony_ci      finally:
1379ffe3c632Sopenharmony_ci        if cleanup_mox:
1380ffe3c632Sopenharmony_ci          mox_obj.UnsetStubs()
1381ffe3c632Sopenharmony_ci      if cleanup_mox:
1382ffe3c632Sopenharmony_ci        mox_obj.VerifyAll()
1383ffe3c632Sopenharmony_ci    new_method.__name__ = func.__name__
1384ffe3c632Sopenharmony_ci    new_method.__doc__ = func.__doc__
1385ffe3c632Sopenharmony_ci    new_method.__module__ = func.__module__
1386ffe3c632Sopenharmony_ci    return new_method
1387ffe3c632Sopenharmony_ci
1388ffe3c632Sopenharmony_ci
1389ffe3c632Sopenharmony_ciclass MoxTestBase(unittest.TestCase):
1390ffe3c632Sopenharmony_ci  """Convenience test class to make stubbing easier.
1391ffe3c632Sopenharmony_ci
1392ffe3c632Sopenharmony_ci  Sets up a "mox" attribute which is an instance of Mox - any mox tests will
1393ffe3c632Sopenharmony_ci  want this. Also automatically unsets any stubs and verifies that all mock
1394ffe3c632Sopenharmony_ci  methods have been called at the end of each test, eliminating boilerplate
1395ffe3c632Sopenharmony_ci  code.
1396ffe3c632Sopenharmony_ci  """
1397ffe3c632Sopenharmony_ci
1398ffe3c632Sopenharmony_ci  __metaclass__ = MoxMetaTestBase
1399ffe3c632Sopenharmony_ci
1400ffe3c632Sopenharmony_ci  def setUp(self):
1401ffe3c632Sopenharmony_ci    self.mox = Mox()
1402