Static Mocks for Groovy

After experimenting with Groovy's excellent support for mock objects, I was a little disappointed to find that Groovy's MockFor does not support mocking static methods, so I decided to write my own implementation and share it with others.

First, let's start with the enhanced StaticMockFor class itself:

/**
 * Use this class to mock static methods in a similar way to
 * Groovy's MockFor implementation (which does not support
 * mocking of static methods).
 *
 * Demands are strict: calls to the mocked class must arrive in
 * the order the demands were declared, and each demanded method
 * must be called a number of times within its declared
 * cardinality (once by default, or an explicit count or range,
 * e.g. {@code demand.random(2..4) { ... }}).
 *
 * @author Gerardo Viedma
 */
class StaticMockFor {

   /** The class whose static methods are intercepted while {@link #use} runs. */
   final Class clazz
   /** Expectations to satisfy; declare them with {@code demand.methodName(...) { ... }}. */
   final OrderedDemand demand = new OrderedDemand()

   StaticMockFor(Class clazz) {
      this.clazz = clazz
   }

   /**
    * Runs {@code clo} with every static method call on {@link #clazz}
    * routed to {@link #demand}, then verifies the demands.
    *
    * Afterwards the class's customised meta class is discarded, even when
    * the closure or the verification fails, so later calls reach the real
    * static methods again. This also discards any other meta-class
    * customisation made to {@code clazz} before {@code use} was called.
    *
    * @param clo code that exercises the mocked static methods
    * @throws RuntimeException if a call or the final verification does not
    *         match the declared demands
    */
   def use(Closure clo) {
      try {
         // route every static call on the class to the recorded demands
         clazz.metaClass.static.invokeMethod = { String name, args ->
            demand.invoke(name, args)
         }
         // execute the closure
         clo()
         // verify that we satisfied all the demands
         demand.verify()
      }
      finally {
         // Removing the registry entry makes Groovy rebuild the class's
         // original meta class. Installing a hand-written forwarding
         // invokeMethod instead would leave the class permanently modified
         // and would only find static meta methods.
         GroovySystem.metaClassRegistry.removeMetaClass(clazz)
      }
   }
}

Note that StaticMockFor relies on the OrderedDemand type to verify the cardinality and ordering of calls to the mocked class. OrderedDemand uses Groovy's runtime metaprogramming to turn calls such as `demand.random { ... }` into recorded expectations. StaticMockFor uses Groovy's ExpandoMetaClass to redirect the mocked class's static method calls to those expectations.

/**
 * Encapsulates demands for instances of StaticMockFor.
 * The implementation is based on a queue of demanded
 * method invocations that have to be met in order
 * and with the specified cardinality.
 *
 * A demand is declared by calling the method to be mocked on
 * this object with a closure that implements it, optionally
 * preceded by the expected number of calls (an integer or a range):
 * <pre>
 * demand.random { 0.5 }           // exactly once
 * demand.abs(2..3) { a -> a }     // two to three times
 * </pre>
 * The same method name may be demanded more than once. Each
 * declaration is a separate step of the sequence with its own
 * closure and call count.
 *
 * @author Gerardo Viedma
 */
class OrderedDemand {

   // Pending demands in declaration order. Each entry is a map with the keys
   // name (method name), range (accepted call counts), max (largest accepted
   // count), closure (mock implementation) and count (calls received so far).
   final Queue calls = new LinkedList()

   /*
    * Records a demand for any method this class does not define itself,
    * e.g. demand.random { ... } or demand.random(2..4) { ... }.
    * Throws IllegalArgumentException unless the arguments are an optional
    * count or range followed by a closure.
    */
   def methodMissing(String name, args) {
      if (!(args.size() in [1, 2]) || !(args[-1] instanceof Closure)) {
         throw new IllegalArgumentException(
               "Declare a demand as ${name} { ... } or ${name}(count or range) { ... }")
      }
      def range = (args.size() == 1) ? (1..1) : args[0]
      if (!(range instanceof Range)) {
         range = range..range
      }
      calls.add([name: name, range: range, max: range.max(), closure: args[-1], count: 0])
   }

   /*
    * Invokes a demanded method and returns its closure's result.
    * If the demand at the head of the queue is for a different method,
    * or has already received its maximum number of calls, the head is
    * checked for completion and dropped, and the next demand is tried.
    * Throws a RuntimeException if a skipped demand was not satisfied,
    * or if no demand is left to accept the call.
    */
   def invoke(String name, args) {
      while (!calls.isEmpty()) {
         def head = calls.peek()
         if (head.name == name && head.count < head.max) {
            def rslt = head.closure.call(args)
            head.count++
            return rslt
         }
         // The head cannot take this call, so it must already be met before we move past it
         checkSatisfied(calls.remove())
      }
      throw new RuntimeException("Did not expect any calls to $name")
   }

   /*
    * Verifies that all demands have been met by ensuring that every
    * remaining demanded method was invoked within its cardinality.
    * Drains the queue. Throws a RuntimeException for the first unmet demand.
    */
   def verify() {
      while (!calls.isEmpty()) {
         checkSatisfied(calls.remove())
      }
   }

   /* Throws a RuntimeException if the given demand's call count is outside its range. */
   private void checkSatisfied(Map demanded) {
      if (!demanded.range.contains(demanded.count)) {
         throw new RuntimeException(
               "Incorrect number of calls to ${demanded.name}: expected ${demanded.range}, got ${demanded.count}")
      }
   }
}

Finally, we can demonstrate usage of the StaticMockFor class and verify its behavior by writing some unit tests:

/**
 * Tests functionality of StaticMockFor instances.
 *
 * Every test mocks Math.random() (and, in the ordering tests, Math.abs())
 * and checks that the real Math.random() is reachable again once the
 * mock's use block has finished.
 *
 * @author Gerardo Viedma
 */
class StaticMockForTest extends GroovyTestCase {

   // Typed as double so the mocks return the same type as the real
   // Math.random(); an untyped 0.5 literal would be a BigDecimal in Groovy.
   static final double NOT_SO_RANDOM = 0.5

   /*
    * Creates a Math mock demanding random() the given number of times.
    * callCount may be an exact count or a range of accepted counts.
    */
   private StaticMockFor mockRandom(callCount) {
      def mathMock = new StaticMockFor(Math)
      mathMock.demand.random(callCount) { NOT_SO_RANDOM }
      mathMock
   }

   /* Asserts that Math.random() is no longer intercepted by a mock. */
   private void assertRandomRestored() {
      assertFalse 'Math.random() is still mocked', Math.random() == NOT_SO_RANDOM
   }

   void testSingleCall() {
      def mathMock = new StaticMockFor(Math)
      mathMock.demand.random { NOT_SO_RANDOM }
      mathMock.use {
         assertEquals NOT_SO_RANDOM, Math.random()
      }
      assertRandomRestored()
   }

   void testMultipleCalls() {
      def mathMock = mockRandom(3)
      mathMock.use {
         3.times { assertEquals NOT_SO_RANDOM, Math.random() }
      }
      assertRandomRestored()
   }

   void testCallRange() {
      def mathMock = mockRandom(2..4)
      mathMock.use {
         3.times { assertEquals NOT_SO_RANDOM, Math.random() }
      }
      assertRandomRestored()
   }

   // Inside shouldFail only the mock itself may fail; no assertions there.
   void testMissingCalls() {
      def mathMock = mockRandom(4)
      shouldFail {
         mathMock.use {
            3.times { Math.random() }
         }
      }
      assertRandomRestored()
   }

   void testMissingCallsInRange() {
      def mathMock = mockRandom(2..4)
      shouldFail {
         mathMock.use {
            Math.random()
         }
      }
      assertRandomRestored()
   }

   void testExceededCalls() {
      def mathMock = mockRandom(2)
      shouldFail {
         mathMock.use {
            3.times { Math.random() }
         }
      }
      assertRandomRestored()
   }

   void testExceededCallsInRange() {
      def mathMock = mockRandom(1..2)
      shouldFail {
         mathMock.use {
            3.times { Math.random() }
         }
      }
      assertRandomRestored()
   }

   void testCorrectCallOrder() {
      def mathMock = mockRandom(1)
      mathMock.demand.abs { a -> (a > 0) ? a : -a }
      mathMock.use {
         assertEquals NOT_SO_RANDOM, Math.random()
         assertEquals NOT_SO_RANDOM, Math.abs(-NOT_SO_RANDOM)
      }
      assertRandomRestored()
   }

   void testWrongCallOrder() {
      def mathMock = mockRandom(1)
      mathMock.demand.abs { a -> (a > 0) ? a : -a }
      shouldFail {
         mathMock.use {
            Math.abs(NOT_SO_RANDOM)
            Math.random()
         }
      }
      assertRandomRestored()
   }

}