12 "github.com/davecgh/go-spew/spew"
13 "github.com/pmezard/go-difflib/difflib"
14 "github.com/stretchr/objx"
15 "github.com/stretchr/testify/assert"
18 // TestingT is an interface wrapper around *testing.T
19 type TestingT interface {
20 Logf(format string, args ...interface{})
21 Errorf(format string, args ...interface{})
29 // Call represents a method call and is used for setting expectations,
30 // as well as recording activity.
34 // The name of the method that was or will be called.
37 // Holds the arguments of the method.
40 // Holds the arguments that should be returned when
41 // this method is called.
42 ReturnArguments Arguments
44 // The number of times to return the return arguments when setting
45 // expectations. 0 means to always return the value.
48 // Amount of times this call has been called
51 // Holds a channel that will be used to block the Return until it either
52 // receives a message or is closed. nil means it returns immediately.
53 WaitFor <-chan time.Time
55 // Holds a handler used to manipulate arguments content that are passed by
56 // reference. It's useful when mocking methods such as unmarshalers or
61 func newCall(parent *Mock, methodName string, methodArguments ...interface{}) *Call {
65 Arguments: methodArguments,
66 ReturnArguments: make([]interface{}, 0),
73 func (c *Call) lock() {
77 func (c *Call) unlock() {
78 c.Parent.mutex.Unlock()
81 // Return specifies the return arguments for the expectation.
83 // Mock.On("DoSomething").Return(errors.New("failed"))
84 func (c *Call) Return(returnArguments ...interface{}) *Call {
88 c.ReturnArguments = returnArguments
93 // Once indicates that the mock should only return the value once.
95 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
96 func (c *Call) Once() *Call {
100 // Twice indicates that the mock should only return the value twice.
102 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
103 func (c *Call) Twice() *Call {
107 // Times indicates that the mock should only return the indicated number
110 // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
111 func (c *Call) Times(i int) *Call {
118 // WaitUntil sets the channel that will block the mock's return until it's closed
119 // or a message is received.
121 // Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
122 func (c *Call) WaitUntil(w <-chan time.Time) *Call {
129 // After sets how long to block until the call returns.
131 // Mock.On("MyMethod", arg1, arg2).After(time.Second)
132 func (c *Call) After(d time.Duration) *Call {
133 return c.WaitUntil(time.After(d))
136 // Run sets a handler to be called before returning. It can be used when
137 // mocking a method (such as an unmarshaler) that takes a pointer to a struct and
138 // sets properties in that struct.
140 // Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
141 // arg := args.Get(0).(*map[string]interface{})
142 // (*arg)["foo"] = "bar"
144 func (c *Call) Run(fn func(args Arguments)) *Call {
151 // On chains a new expectation description onto the mocked interface. This
152 // allows syntax like:
155 // On("MyMethod", 1).Return(nil).
156 // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
157 func (c *Call) On(methodName string, arguments ...interface{}) *Call {
158 return c.Parent.On(methodName, arguments...)
161 // Mock is the workhorse used to track activity on another object.
162 // For an example of its usage, refer to the "Example Usage" section at the top
165 // Represents the calls that are expected of
167 ExpectedCalls []*Call
169 // Holds the calls that were made to this mocked object.
172 // TestData holds any data that might be useful for testing. Testify ignores
173 // this data completely allowing you to do whatever you like with it.
179 // TestData holds any data that might be useful for testing. Testify ignores
180 // this data completely allowing you to do whatever you like with it.
181 func (m *Mock) TestData() objx.Map {
183 if m.testData == nil {
184 m.testData = make(objx.Map)
194 // On starts a description of an expectation of the specified method
197 // Mock.On("MyMethod", arg1, arg2)
198 func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
199 for _, arg := range arguments {
200 if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
201 panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
206 defer m.mutex.Unlock()
207 c := newCall(m, methodName, arguments...)
208 m.ExpectedCalls = append(m.ExpectedCalls, c)
213 // Recording and responding to activity
216 func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
217 for i, call := range m.ExpectedCalls {
218 if call.Method == method && call.Repeatability > -1 {
220 _, diffCount := call.Arguments.Diff(arguments)
230 func (m *Mock) findClosestCall(method string, arguments ...interface{}) (bool, *Call) {
232 var closestCall *Call
234 for _, call := range m.expectedCalls() {
235 if call.Method == method {
237 _, tempDiffCount := call.Arguments.Diff(arguments)
238 if tempDiffCount < diffCount || diffCount == 0 {
239 diffCount = tempDiffCount
246 if closestCall == nil {
250 return true, closestCall
253 func callString(method string, arguments Arguments, includeArgumentValues bool) string {
255 var argValsString string
256 if includeArgumentValues {
258 for argIndex, arg := range arguments {
259 argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
261 argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
264 return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
267 // Called tells the mock object that a method has been called, and gets an array
268 // of arguments to return. Panics if the call is unexpected (i.e. not preceded by
269 // appropriate .On .Return() calls).
270 // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
271 func (m *Mock) Called(arguments ...interface{}) Arguments {
272 // get the calling function's name
273 pc, _, _, ok := runtime.Caller(1)
275 panic("Couldn't get the caller information")
277 functionPath := runtime.FuncForPC(pc).Name()
278 //Next four lines are required to use GCCGO function naming conventions.
279 //For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
280 //uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
281 //With GCCGO we need to remove interface information starting from pN<dd>.
282 re := regexp.MustCompile("\\.pN\\d+_")
283 if re.MatchString(functionPath) {
284 functionPath = re.Split(functionPath, -1)[0]
286 parts := strings.Split(functionPath, ".")
287 functionName := parts[len(parts)-1]
288 return m.MethodCalled(functionName, arguments...)
291 // MethodCalled tells the mock object that the given method has been called, and gets
292 // an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
293 // by appropriate .On .Return() calls).
294 // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
295 func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
297 found, call := m.findExpectedCall(methodName, arguments...)
300 // we have to fail here - because we don't know what to do
301 // as the return arguments. This is because:
303 // a) this is a totally unexpected call to this method,
304 // b) the arguments are not what was expected, or
305 // c) the developer has forgotten to add an accompanying On...Return pair.
307 closestFound, closestCall := m.findClosestCall(methodName, arguments...)
311 panic(fmt.Sprintf("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\n", callString(methodName, arguments, true), callString(methodName, closestCall.Arguments, true), diffArguments(arguments, closestCall.Arguments)))
313 panic(fmt.Sprintf("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo()))
318 case call.Repeatability == 1:
319 call.Repeatability = -1
322 case call.Repeatability > 1:
326 case call.Repeatability == 0:
331 m.Calls = append(m.Calls, *newCall(m, methodName, arguments...))
334 // block if specified
335 if call.WaitFor != nil {
348 returnArgs := call.ReturnArguments
358 type assertExpectationser interface {
359 AssertExpectations(TestingT) bool
362 // AssertExpectationsForObjects asserts that everything specified with On and Return
363 // of the specified objects was in fact called as expected.
365 // Calls may have occurred in any order.
366 func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
367 for _, obj := range testObjects {
368 if m, ok := obj.(Mock); ok {
369 t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
372 m := obj.(assertExpectationser)
373 if !m.AssertExpectations(t) {
380 // AssertExpectations asserts that everything specified with On and Return was
381 // in fact called as expected. Calls may have occurred in any order.
382 func (m *Mock) AssertExpectations(t TestingT) bool {
384 defer m.mutex.Unlock()
385 var somethingMissing bool
386 var failedExpectations int
388 // iterate through each expectation
389 expectedCalls := m.expectedCalls()
390 for _, expectedCall := range expectedCalls {
391 if !m.methodWasCalled(expectedCall.Method, expectedCall.Arguments) && expectedCall.totalCalls == 0 {
392 somethingMissing = true
394 t.Logf("\u274C\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
396 if expectedCall.Repeatability > 0 {
397 somethingMissing = true
400 t.Logf("\u2705\t%s(%s)", expectedCall.Method, expectedCall.Arguments.String())
405 if somethingMissing {
406 t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
409 return !somethingMissing
412 // AssertNumberOfCalls asserts that the method was called expectedCalls times.
413 func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
415 defer m.mutex.Unlock()
417 for _, call := range m.calls() {
418 if call.Method == methodName {
422 return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
425 // AssertCalled asserts that the method was called.
426 // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
427 func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
429 defer m.mutex.Unlock()
430 if !assert.True(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method should have been called with %d argument(s), but was not.", methodName, len(arguments))) {
431 t.Logf("%v", m.expectedCalls())
437 // AssertNotCalled asserts that the method was not called.
438 // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
439 func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
441 defer m.mutex.Unlock()
442 if !assert.False(t, m.methodWasCalled(methodName, arguments), fmt.Sprintf("The \"%s\" method was called with %d argument(s), but should NOT have been.", methodName, len(arguments))) {
443 t.Logf("%v", m.expectedCalls())
449 func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
450 for _, call := range m.calls() {
451 if call.Method == methodName {
453 _, differences := Arguments(expected).Diff(call.Arguments)
455 if differences == 0 {
456 // found the expected call
462 // we didn't find the expected call
466 func (m *Mock) expectedCalls() []*Call {
467 return append([]*Call{}, m.ExpectedCalls...)
470 func (m *Mock) calls() []Call {
471 return append([]Call{}, m.Calls...)
478 // Arguments holds the arguments of a method call or the values a mocked method returns; read them by index with Get, String, Int, Error or Bool.
479 type Arguments []interface{}
482 // Anything is used in Diff and Assert when the argument being tested
483 // shouldn't be taken into consideration.
484 Anything string = "mock.Anything"
487 // AnythingOfTypeArgument is a string naming the type an argument must have. Diff (and so Assert)
488 // counts it as a match when the argument's reflect type Name() or String() equals this string.
489 type AnythingOfTypeArgument string
491 // AnythingOfType returns an AnythingOfTypeArgument object containing the
492 // name of the type to check for. Used in Diff and Assert.
495 // Assert(t, AnythingOfType("string"), AnythingOfType("int"))
496 func AnythingOfType(t string) AnythingOfTypeArgument {
497 return AnythingOfTypeArgument(t)
500 // argumentMatcher performs custom argument matching, returning whether or
501 // not the argument is matched by the expectation fixture function.
502 type argumentMatcher struct {
503 // fn is a function which accepts one argument, and returns a bool.
507 func (f argumentMatcher) Matches(argument interface{}) bool {
508 expectType := f.fn.Type().In(0)
510 if reflect.TypeOf(argument).AssignableTo(expectType) {
511 result := f.fn.Call([]reflect.Value{reflect.ValueOf(argument)})
512 return result[0].Bool()
517 func (f argumentMatcher) String() string {
518 return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).Name())
521 // MatchedBy can be used to match a mock call based on only certain properties
522 // from a complex struct or some calculation. It takes a function that will be
523 // evaluated with the called argument and will return true when there's a match
524 // and false otherwise.
527 // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
529 // |fn| must be a function accepting a single argument (of the expected type)
530 // which returns a bool. If |fn| doesn't match the required signature,
531 // MatchedBy() panics.
532 func MatchedBy(fn interface{}) argumentMatcher {
533 fnType := reflect.TypeOf(fn)
535 if fnType.Kind() != reflect.Func {
536 panic(fmt.Sprintf("assert: arguments: %s is not a func", fn))
538 if fnType.NumIn() != 1 {
539 panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
541 if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
542 panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
545 return argumentMatcher{fn: reflect.ValueOf(fn)}
548 // Get returns the argument at the specified index.
549 func (args Arguments) Get(index int) interface{} {
550 if index+1 > len(args) {
551 panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
556 // Is gets whether the objects match the arguments specified.
557 func (args Arguments) Is(objects ...interface{}) bool {
558 for i, obj := range args {
559 if obj != objects[i] {
566 // Diff gets a string describing the differences between the arguments
567 // and the specified objects.
569 // Returns the diff string and number of differences found.
570 func (args Arguments) Diff(objects []interface{}) (string, int) {
575 var maxArgCount = len(args)
576 if len(objects) > maxArgCount {
577 maxArgCount = len(objects)
580 for i := 0; i < maxArgCount; i++ {
581 var actual, expected interface{}
583 if len(objects) <= i {
590 expected = "(Missing)"
595 if matcher, ok := expected.(argumentMatcher); ok {
596 if matcher.Matches(actual) {
597 output = fmt.Sprintf("%s\t%d: \u2705 %s matched by %s\n", output, i, actual, matcher)
600 output = fmt.Sprintf("%s\t%d: \u2705 %s not matched by %s\n", output, i, actual, matcher)
602 } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {
605 if reflect.TypeOf(actual).Name() != string(expected.(AnythingOfTypeArgument)) && reflect.TypeOf(actual).String() != string(expected.(AnythingOfTypeArgument)) {
608 output = fmt.Sprintf("%s\t%d: \u274C type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actual)
615 if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
617 output = fmt.Sprintf("%s\t%d: \u2705 %s == %s\n", output, i, actual, expected)
621 output = fmt.Sprintf("%s\t%d: \u274C %s != %s\n", output, i, actual, expected)
627 if differences == 0 {
628 return "No differences.", differences
631 return output, differences
635 // Assert compares the arguments with the specified objects and fails if
636 // they do not exactly match.
637 func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
639 // get the differences
640 diff, diffCount := args.Diff(objects)
646 // there are differences... report them...
648 t.Errorf("%sArguments do not match.", assert.CallerInfo())
654 // String gets the argument at the specified index. Panics if there is no argument, or
655 // if the argument is of the wrong type.
657 // If no index is provided, String() returns a complete string representation
659 func (args Arguments) String(indexOrNil ...int) string {
661 if len(indexOrNil) == 0 {
662 // normal String() method - return a string representation of the args
664 for _, arg := range args {
665 argsStr = append(argsStr, fmt.Sprintf("%s", reflect.TypeOf(arg)))
667 return strings.Join(argsStr, ",")
668 } else if len(indexOrNil) == 1 {
669 // Index has been specified - get the argument at that index
670 var index = indexOrNil[0]
673 if s, ok = args.Get(index).(string); !ok {
674 panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
679 panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
683 // Int gets the argument at the specified index. Panics if there is no argument, or
684 // if the argument is of the wrong type.
685 func (args Arguments) Int(index int) int {
688 if s, ok = args.Get(index).(int); !ok {
689 panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
694 // Error gets the argument at the specified index. Panics if there is no argument, or
695 // if the argument is of the wrong type.
696 func (args Arguments) Error(index int) error {
697 obj := args.Get(index)
703 if s, ok = obj.(error); !ok {
704 panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
709 // Bool gets the argument at the specified index. Panics if there is no argument, or
710 // if the argument is of the wrong type.
711 func (args Arguments) Bool(index int) bool {
714 if s, ok = args.Get(index).(bool); !ok {
715 panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
720 func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
721 t := reflect.TypeOf(v)
724 if k == reflect.Ptr {
731 func diffArguments(expected Arguments, actual Arguments) string {
732 if len(expected) != len(actual) {
733 return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
736 for x := range expected {
737 if diffString := diff(expected[x], actual[x]); diffString != "" {
738 return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
745 // diff returns a diff of both values as long as both are of the same type and
746 // are a struct, map, slice or array. Otherwise it returns an empty string.
747 func diff(expected interface{}, actual interface{}) string {
748 if expected == nil || actual == nil {
752 et, ek := typeAndKind(expected)
753 at, _ := typeAndKind(actual)
759 if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
763 e := spewConfig.Sdump(expected)
764 a := spewConfig.Sdump(actual)
766 diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
767 A: difflib.SplitLines(e),
768 B: difflib.SplitLines(a),
769 FromFile: "Expected",
779 var spewConfig = spew.ConfigState{
781 DisablePointerAddresses: true,
782 DisableCapacities: true,