OSDN Git Service

Use correct typeclass in compare_tests
[transunit/transunit.git] / transunit.compare.m
1 %   Copyright (C) 2018-2019 Alaskan Emily, Transnat Games
2 %
3 %   This software is provided 'as-is', without any express or implied
4 %   warranty.  In no event will the authors be held liable for any damages
5 %   arising from the use of this software.
6 %
7 %   Permission is granted to anyone to use this software for any purpose,
8 %   including commercial applications, and to alter it and redistribute it
9 %   freely, subject to the following restrictions:
10 %
11 %   1. The origin of this software must not be misrepresented; you must not
12 %      claim that you wrote the original software. If you use this software
13 %      in a product, an acknowledgment in the product documentation would be
14 %      appreciated but is not required.
15 %   2. Altered source versions must be plainly marked as such, and must not be
16 %      misrepresented as being the original software.
17 %  3. This notice may not be removed or altered from any source distribution.
18
19 :- module transunit.compare.
20
21 %==============================================================================%
22 % General components for the unit test framework.
23 % I know this isn't great. But it has no dependencies, and there are not a lot
24 % of prebuilt solutions for Mercury.
25 :- interface.
26 %==============================================================================%
27
28 :- use_module bool.
29 :- use_module rbtree.
30 :- use_module array.
31 :- use_module array2d.
32 :- use_module set.
33
34 %------------------------------------------------------------------------------%
35
36 :- instance to_string(int).
37 :- instance to_string(string).
38 :- instance to_string(float).
39 :- instance to_string(bool.bool).
40 :- instance to_string(maybe.maybe(T)) <= to_string(T).
41
42 %------------------------------------------------------------------------------%
43
44 :- instance compare(list(T)) <= (compare(T), to_string(T)).
45 :- instance compare(set.set(T)) <= (compare(T), to_string(T)).
46 %:- instance compare(rbtree.rbtree(K, V)) <= (compare(V), to_string(K), to_string(V)).
47 %:- instance compare(tree.tree(K, V)) <= (compare(V), to_string(K), to_string(V)).
48 :- instance compare(int).
49 :- instance compare(string).
50 :- instance compare(float).
51 :- instance compare(bool.bool).
52 :- instance compare(maybe.maybe(T)) <= (to_string(T), compare(T)).
53 :- instance compare(array.array(T)) <= (to_string(T), compare(T)).
54 :- instance compare(array2d.array2d(T)) <= (to_string(T), compare(T)).
55
56 %------------------------------------------------------------------------------%
57
58 :- func generic_compare(T, T) = maybe.maybe_error <= to_string(T).
59
60 %------------------------------------------------------------------------------%
61
62 :- func simple_compare(T, T) = maybe.maybe_error.
63
64 %------------------------------------------------------------------------------%
65
66 :- func negate(float) = float.
67
68 %------------------------------------------------------------------------------%
69 % float_equals(A, B)
70 :- pred float_equals(float, float).
71 :- mode float_equals(in, in) is semidet.
72 :- mode float_equals(di, di) is semidet.
73
74 %------------------------------------------------------------------------------%
75 % Promise the associativity of float comparisons
76 :- promise all[A, B] (
77     float_equals(A, B) <=> float_equals(B, A)
78 ).
79
80 %------------------------------------------------------------------------------%
81
82 :- promise all[A, B] (
83     float_equals(A, B) <=> float_equals(negate(A), negate(B))
84 ).
85
86 %------------------------------------------------------------------------------%
87
88 :- promise all[A, B] (
89     (negate(A) = B) <=> (negate(B) = A)
90 ).
91
92 %------------------------------------------------------------------------------%
93
94 :- promise all[A, B] (
95     some [C] (negate(A) = C, negate(B) = C, A = B)
96 ).
97
98 %------------------------------------------------------------------------------%
99 % float_equals(A, B, Epsilon)
100 :- pred float_equals(float, float, float).
101 :- mode float_equals(in, in, in) is semidet.
102 :- mode float_equals(di, di, in) is semidet.
103
104 %------------------------------------------------------------------------------%
105
106 :- promise all[A, B, Epsilon] (
107     float_equals(A, B, Epsilon) <=> float_equals(B, A, Epsilon)
108 ).
109
110 %==============================================================================%
111 :- implementation.
112 %==============================================================================%
113
114 :- import_module float.
115 :- use_module int.
116 :- use_module string.
117 :- use_module std_util.
118
119 %------------------------------------------------------------------------------%
120
121 :- instance to_string(int) where [
122     func(to_string/1) is string.from_int
123 ].
124
125 :- instance to_string(string) where [
126     func(to_string/1) is std_util.id
127 ].
128
129 :- instance to_string(float) where [
130     func(to_string/1) is string.from_float
131 ].
132
133 :- instance to_string(bool.bool) where [
134     (to_string(bool.yes) = "bool.yes"),
135     (to_string(bool.no) = "bool.no")
136 ].
137
138 :- instance to_string(maybe.maybe(T)) <= to_string(T) where [
139     (to_string(maybe.yes(That)) = to_string(That)),
140     (to_string(maybe.no) = "maybe.no")
141 ].
142
143 %------------------------------------------------------------------------------%
144
145 generic_compare(A, B) = Result :-
146     ( if
147         A = B 
148     then
149         Result = maybe.ok
150     else
151         Message = string.join_list(" != ", map(to_string, [A|[B|[]]])),
152         Result = maybe.error(Message)
153     ).
154
155 %------------------------------------------------------------------------------%
156
157 simple_compare(A, B) = Result :-
158     ( A = B -> Result = maybe.ok ; Result = maybe.error("Not equal") ).
159
160 %------------------------------------------------------------------------------%
161
162 :- pred accumulate_mismatch(T, T, list(string), list(string), int, int)
163     <= compare(T).
164 :- mode accumulate_mismatch(in, in, in, out, in, out) is det.
165
166 accumulate_mismatch(A, B, !List, I, int.plus(I, 1)) :-
167     compare(A, B) = MaybeResult,
168     (
169         MaybeResult = maybe.ok
170     ;
171         MaybeResult = maybe.error(Error),
172         string.append("Element ", string.from_int(I), Prefix),
173         string.append(string.append(Prefix, "\t: "), Error, Message),
174         list.cons(Message, !List)
175     ).
176
177 %------------------------------------------------------------------------------%
178
179 :- instance compare(list(T)) <= (compare(T), to_string(T)) where [
180     ( compare(A, B) = Result :-
181         list.length(A, ALen), list.length(B, BLen),
182         generic_compare(ALen, BLen) = LenCompare,
183         (
184             LenCompare = maybe.ok,
185             list.foldl2_corresponding(accumulate_mismatch, A, B, [], Errors, 0, _),
186             ( if
187                 list.is_empty(Errors)
188             then
189                 Result = maybe.ok
190             else
191                 Result = maybe.error(string.join_list("\n", Errors))
192             )
193         ;
194             LenCompare = maybe.error(Error),
195             Result = maybe.error(string.append("List length ", Error))
196         )
197     )
198 ].
199
200 :- instance compare(set.set(T)) <= (compare(T), to_string(T)) where [
201     ( compare(A, B) = Result :-
202         set.count(A, ALen), set.count(B, BLen),
203         generic_compare(ALen, BLen) = LenCompare,
204         (
205             LenCompare = maybe.ok,
206             ( set.to_sorted_list(A, AList) & set.to_sorted_list(B, BList) ),
207             compare(AList, BList) = Result
208         ;
209             LenCompare = maybe.error(Error),
210             Result = maybe.error(string.append("List length ", Error))
211         )
212     )
213 ].
214
215 %:- instance compare(rbtree.rbtree(K, V)) <= (compare(V), to_string(K), to_string(V)) where [
216 %].
217
218 %:- instance compare(tree.tree(K, V)) <= (compare(V), to_string(K), to_string(V)) where [
219 %].
220
221 :- instance compare(int) where [
222     func(compare/2) is generic_compare
223 ].
224
225 :- instance compare(string) where [
226     ( compare(A, B) = Result :-
227         ( A = B -> Result = maybe.ok
228         ; Result = maybe.error(string.join_list(" != ", [A|[B|[]]])) )
229     )
230 ].
231
232 :- instance compare(float) where [
233     ( compare(A, B) = Result :-
234         ( float_equals(A, B) -> Result = maybe.ok
235         ; Message = string.join_list(" != ", map(string.from_float, [A|[B|[]]])),
236           Result = maybe.error(Message) )
237     )
238 ].
239
240 :- instance compare(bool.bool) where [
241     ( compare(bool.yes, bool.yes) = maybe.ok ),
242     ( compare(bool.no, bool.no) = maybe.ok ),
243     ( compare(bool.yes, bool.no) = maybe.error("bool.yes != bool.no") ),
244     ( compare(bool.no, bool.yes) = maybe.error("bool.no != bool.yes") )
245 ].
246
247 :- instance compare(maybe.maybe(T)) <= (to_string(T), compare(T)) where [
248     func(compare/2) is generic_compare
249 ].
250
251 :- instance compare(array.array(T)) <= (to_string(T), compare(T)) where [
252     ( compare(A, B) = Result :-
253         array.size(A, ALen), array.size(B, BLen),
254         generic_compare(ALen, BLen) = LenCompare,
255         (
256             LenCompare = maybe.ok,
257             ( array.to_list(A, AList) & array.to_list(B, BList) ),
258             compare(AList, BList) = Result
259         ;
260             LenCompare = maybe.error(Error),
261             Result = maybe.error(string.append("Array length ", Error))
262         )
263     )
264 ].
265
266 :- instance compare(array2d.array2d(T)) <= (to_string(T), compare(T)) where [
267     ( compare(A, B) = Result :-
268         array2d.bounds(A, AW, AH), array2d.bounds(B, BW, BH),
269         generic_compare(AW, BW) = WCompare,
270         generic_compare(AH, BH) = HCompare,
271         (
272             WCompare = maybe.ok,
273             HCompare = maybe.ok,
274             % Kind of silly. Join the lists.
275             (
276               ( array2d.lists(A) = ALists,
277                 list.foldl(list.append, ALists, []) = AList ) &
278               ( array2d.lists(B) = BLists,
279                 list.foldl(list.append, BLists, []) = BList )
280             ),
281             compare(AList, BList) = Result
282         ;
283             WCompare = maybe.ok,
284             HCompare = maybe.error(Error),
285             Result = maybe.error(string.append("Array2D height ", Error))
286         ;
287             WCompare = maybe.error(Error),
288             HCompare = maybe.ok,
289             Result = maybe.error(string.append("Array2D width ", Error))
290         ;
291             WCompare = maybe.error(WError),
292             HCompare = maybe.error(HError),
293             string.append("Array2D width ", WError, W),
294             string.append("Array2D height ", HError, H),
295             Result = maybe.error(string.join_list("\n", [W|[H|[]]]))
296         )
297     )
298 ].
299
300 %------------------------------------------------------------------------------%
301
302 negate(X) = -X.
303
304 %------------------------------------------------------------------------------%
305
306 float_equals(A, B) :-
307     abs(A - B) =< float.epsilon.
308
309 float_equals(A, B, Epsilon) :-
310     abs(A - B) =< Epsilon.