OSDN Git Service

[globalisel][tablegen] Add support for C++ predicates on PatFrags and use it to suppo...
[android-x86/external-llvm.git] / utils / update_mir_test_checks.py
1 #!/usr/bin/env python
2
3 """Updates FileCheck checks in MIR tests.
4
5 This script is a utility to update MIR based tests with new FileCheck
6 patterns.
7
8 The checks added by this script will cover the entire body of each
9 function it handles. Virtual registers used are given names via
10 FileCheck patterns, so if you do want to check a subset of the body it
11 should be straightforward to trim out the irrelevant parts. None of
12 the YAML metadata will be checked, other than function names.
13
14 If there are multiple llc commands in a test, the full set of checks
15 will be repeated for each different check pattern. Checks for patterns
16 that are common between different commands will be left as-is by
17 default, or removed if the --remove-common-prefixes flag is provided.
18 """
19
20 from __future__ import print_function
21
22 import argparse
23 import collections
24 import os
25 import re
26 import subprocess
27 import sys
28
29 from UpdateTestChecks import common
30
31 MIR_FUNC_NAME_RE = re.compile(r' *name: *(?P<func>[A-Za-z0-9_.-]+)')
32 MIR_BODY_BEGIN_RE = re.compile(r' *body: *\|')
33 MIR_BASIC_BLOCK_RE = re.compile(r' *bb\.[0-9]+.*:$')
34 VREG_RE = re.compile(r'(%[0-9]+)(?::[a-z0-9_]+)?(?:\([<>a-z0-9 ]+\))?')
35 VREG_DEF_RE = re.compile(
36     r'^ *(?P<vregs>{0}(?:, {0})*) '
37     r'= (?P<opcode>[A-Zt][A-Za-z0-9_]+)'.format(VREG_RE.pattern))
38 MIR_PREFIX_DATA_RE = re.compile(r'^ *(;|bb.[0-9].*: *$|[a-z]+:( |$)|$)')
39
40 IR_FUNC_NAME_RE = re.compile(
41     r'^\s*define\s+(?:internal\s+)?[^@]*@(?P<func>[A-Za-z0-9_.]+)\s*\(')
42 IR_PREFIX_DATA_RE = re.compile(r'^ *(;|$)')
43
44 MIR_FUNC_RE = re.compile(
45     r'^---$'
46     r'\n'
47     r'^ *name: *(?P<func>[A-Za-z0-9_.-]+)$'
48     r'.*?'
49     r'^ *body: *\|\n'
50     r'(?P<body>.*?)\n'
51     r'^\.\.\.$',
52     flags=(re.M | re.S))
53
54
55 class LLC:
56     def __init__(self, bin):
57         self.bin = bin
58
59     def __call__(self, args, ir):
60         if ir.endswith('.mir'):
61             args = '{} -x mir'.format(args)
62         with open(ir) as ir_file:
63             stdout = subprocess.check_output('{} {}'.format(self.bin, args),
64                                              shell=True, stdin=ir_file)
65             # Fix line endings to unix CR style.
66             stdout = stdout.replace('\r\n', '\n')
67         return stdout
68
69
70 class Run:
71     def __init__(self, prefixes, cmd_args, triple):
72         self.prefixes = prefixes
73         self.cmd_args = cmd_args
74         self.triple = triple
75
76     def __getitem__(self, index):
77         return [self.prefixes, self.cmd_args, self.triple][index]
78
79
80 def log(msg, verbose=True):
81     if verbose:
82         print(msg, file=sys.stderr)
83
84
85 def warn(msg, test_file=None):
86     if test_file:
87         msg = '{}: {}'.format(test_file, msg)
88     print('WARNING: {}'.format(msg), file=sys.stderr)
89
90
91 def find_triple_in_ir(lines, verbose=False):
92     for l in lines:
93         m = common.TRIPLE_IR_RE.match(l)
94         if m:
95             return m.group(1)
96     return None
97
98
99 def find_run_lines(test, lines, verbose=False):
100     raw_lines = [m.group(1)
101                  for m in [common.RUN_LINE_RE.match(l) for l in lines] if m]
102     run_lines = [raw_lines[0]] if len(raw_lines) > 0 else []
103     for l in raw_lines[1:]:
104         if run_lines[-1].endswith("\\"):
105             run_lines[-1] = run_lines[-1].rstrip("\\") + " " + l
106         else:
107             run_lines.append(l)
108     if verbose:
109         log('Found {} RUN lines:'.format(len(run_lines)))
110         for l in run_lines:
111             log('  RUN: {}'.format(l))
112     return run_lines
113
114
115 def build_run_list(test, run_lines, verbose=False):
116     run_list = []
117     all_prefixes = []
118     for l in run_lines:
119         commands = [cmd.strip() for cmd in l.split('|', 1)]
120         llc_cmd = commands[0]
121         filecheck_cmd = commands[1] if len(commands) > 1 else ''
122
123         if not llc_cmd.startswith('llc '):
124             warn('Skipping non-llc RUN line: {}'.format(l), test_file=test)
125             continue
126         if not filecheck_cmd.startswith('FileCheck '):
127             warn('Skipping non-FileChecked RUN line: {}'.format(l),
128                  test_file=test)
129             continue
130
131         triple = None
132         m = common.TRIPLE_ARG_RE.search(llc_cmd)
133         if m:
134             triple = m.group(1)
135         # If we find -march but not -mtriple, use that.
136         m = common.MARCH_ARG_RE.search(llc_cmd)
137         if m and not triple:
138             triple = '{}--'.format(m.group(1))
139
140         cmd_args = llc_cmd[len('llc'):].strip()
141         cmd_args = cmd_args.replace('< %s', '').replace('%s', '').strip()
142
143         check_prefixes = [
144             item
145             for m in common.CHECK_PREFIX_RE.finditer(filecheck_cmd)
146             for item in m.group(1).split(',')]
147         if not check_prefixes:
148             check_prefixes = ['CHECK']
149         all_prefixes += check_prefixes
150
151         run_list.append(Run(check_prefixes, cmd_args, triple))
152
153     # Remove any common prefixes. We'll just leave those entirely alone.
154     common_prefixes = set([prefix for prefix in all_prefixes
155                            if all_prefixes.count(prefix) > 1])
156     for run in run_list:
157         run.prefixes = [p for p in run.prefixes if p not in common_prefixes]
158
159     return run_list, common_prefixes
160
161
162 def find_functions_with_one_bb(lines, verbose=False):
163     result = []
164     cur_func = None
165     bbs = 0
166     for line in lines:
167         m = MIR_FUNC_NAME_RE.match(line)
168         if m:
169             if bbs == 1:
170                 result.append(cur_func)
171             cur_func = m.group('func')
172             bbs = 0
173         m = MIR_BASIC_BLOCK_RE.match(line)
174         if m:
175             bbs += 1
176     if bbs == 1:
177         result.append(cur_func)
178     return result
179
180
181 def build_function_body_dictionary(test, raw_tool_output, triple, prefixes,
182                                    func_dict, verbose):
183     for m in MIR_FUNC_RE.finditer(raw_tool_output):
184         func = m.group('func')
185         body = m.group('body')
186         if verbose:
187             log('Processing function: {}'.format(func))
188             for l in body.splitlines():
189                 log('  {}'.format(l))
190         for prefix in prefixes:
191             if func in func_dict[prefix] and func_dict[prefix][func] != body:
192                 warn('Found conflicting asm for prefix: {}'.format(prefix),
193                      test_file=test)
194             func_dict[prefix][func] = body
195
196
197 def add_checks_for_function(test, output_lines, run_list, func_dict, func_name,
198                             single_bb, verbose=False):
199     printed_prefixes = set()
200     for run in run_list:
201         for prefix in run.prefixes:
202             if prefix in printed_prefixes:
203                 continue
204             if not func_dict[prefix][func_name]:
205                 continue
206             # if printed_prefixes:
207             #     # Add some space between different check prefixes.
208             #     output_lines.append('')
209             printed_prefixes.add(prefix)
210             log('Adding {} lines for {}'.format(prefix, func_name), verbose)
211             add_check_lines(test, output_lines, prefix, func_name, single_bb,
212                             func_dict[prefix][func_name].splitlines())
213             break
214     return output_lines
215
216
217 def add_check_lines(test, output_lines, prefix, func_name, single_bb,
218                     func_body):
219     if single_bb:
220         # Don't bother checking the basic block label for a single BB
221         func_body.pop(0)
222
223     if not func_body:
224         warn('Function has no instructions to check: {}'.format(func_name),
225              test_file=test)
226         return
227
228     first_line = func_body[0]
229     indent = len(first_line) - len(first_line.lstrip(' '))
230     # A check comment, indented the appropriate amount
231     check = '{:>{}}; {}'.format('', indent, prefix)
232
233     output_lines.append('{}-LABEL: name: {}'.format(check, func_name))
234
235     vreg_map = {}
236     for func_line in func_body:
237         if not func_line.strip():
238             continue
239         m = VREG_DEF_RE.match(func_line)
240         if m:
241             for vreg in VREG_RE.finditer(m.group('vregs')):
242                 name = mangle_vreg(m.group('opcode'), vreg_map.values())
243                 vreg_map[vreg.group(1)] = name
244                 func_line = func_line.replace(
245                     vreg.group(1), '[[{}:%[0-9]+]]'.format(name), 1)
246         for number, name in vreg_map.items():
247             func_line = re.sub(r'{}\b'.format(number), '[[{}]]'.format(name),
248                                func_line)
249         check_line = '{}: {}'.format(check, func_line[indent:]).rstrip()
250         output_lines.append(check_line)
251
252
253 def mangle_vreg(opcode, current_names):
254     base = opcode
255     # Simplify some common prefixes and suffixes
256     if opcode.startswith('G_'):
257         base = base[len('G_'):]
258     if opcode.endswith('_PSEUDO'):
259         base = base[:len('_PSEUDO')]
260     # Shorten some common opcodes with long-ish names
261     base = dict(IMPLICIT_DEF='DEF',
262                 GLOBAL_VALUE='GV',
263                 CONSTANT='C',
264                 FCONSTANT='C',
265                 MERGE_VALUES='MV',
266                 UNMERGE_VALUES='UV',
267                 INTRINSIC='INT',
268                 INTRINSIC_W_SIDE_EFFECTS='INT',
269                 INSERT_VECTOR_ELT='IVEC',
270                 EXTRACT_VECTOR_ELT='EVEC',
271                 SHUFFLE_VECTOR='SHUF').get(base, base)
272     # Avoid ambiguity when opcodes end in numbers
273     if len(base.rstrip('0123456789')) < len(base):
274         base += '_'
275
276     i = 0
277     for name in current_names:
278         if name.rstrip('0123456789') == base:
279             i += 1
280     if i:
281         return '{}{}'.format(base, i)
282     return base
283
284
285 def should_add_line_to_output(input_line, prefix_set):
286     # Skip any check lines that we're handling.
287     m = common.CHECK_RE.match(input_line)
288     if m and m.group(1) in prefix_set:
289         return False
290     return True
291
292
293 def update_test_file(llc, test, remove_common_prefixes=False, verbose=False):
294     log('Scanning for RUN lines in test file: {}'.format(test), verbose)
295     with open(test) as fd:
296         input_lines = [l.rstrip() for l in fd]
297
298     triple_in_ir = find_triple_in_ir(input_lines, verbose)
299     run_lines = find_run_lines(test, input_lines, verbose)
300     run_list, common_prefixes = build_run_list(test, run_lines, verbose)
301
302     simple_functions = find_functions_with_one_bb(input_lines, verbose)
303
304     func_dict = {}
305     for run in run_list:
306         for prefix in run.prefixes:
307             func_dict.update({prefix: dict()})
308     for prefixes, llc_args, triple_in_cmd in run_list:
309         log('Extracted LLC cmd: llc {}'.format(llc_args), verbose)
310         log('Extracted FileCheck prefixes: {}'.format(prefixes), verbose)
311
312         raw_tool_output = llc(llc_args, test)
313         if not triple_in_cmd and not triple_in_ir:
314             warn('No triple found: skipping file', test_file=test)
315             return
316
317         build_function_body_dictionary(test, raw_tool_output,
318                                        triple_in_cmd or triple_in_ir,
319                                        prefixes, func_dict, verbose)
320
321     state = 'toplevel'
322     func_name = None
323     prefix_set = set([prefix for run in run_list for prefix in run.prefixes])
324     log('Rewriting FileCheck prefixes: {}'.format(prefix_set), verbose)
325
326     if remove_common_prefixes:
327         prefix_set.update(common_prefixes)
328     elif common_prefixes:
329         warn('Ignoring common prefixes: {}'.format(common_prefixes),
330              test_file=test)
331
332     comment_char = '#' if test.endswith('.mir') else ';'
333     autogenerated_note = ('{} NOTE: Assertions have been autogenerated by '
334                           'utils/{}'.format(comment_char,
335                                             os.path.basename(__file__)))
336     output_lines = []
337     output_lines.append(autogenerated_note)
338
339     for input_line in input_lines:
340         if input_line == autogenerated_note:
341             continue
342
343         if state == 'toplevel':
344             m = IR_FUNC_NAME_RE.match(input_line)
345             if m:
346                 state = 'ir function prefix'
347                 func_name = m.group('func')
348             if input_line.rstrip('| \r\n') == '---':
349                 state = 'document'
350             output_lines.append(input_line)
351         elif state == 'document':
352             m = MIR_FUNC_NAME_RE.match(input_line)
353             if m:
354                 state = 'mir function metadata'
355                 func_name = m.group('func')
356             if input_line.strip() == '...':
357                 state = 'toplevel'
358                 func_name = None
359             if should_add_line_to_output(input_line, prefix_set):
360                 output_lines.append(input_line)
361         elif state == 'mir function metadata':
362             if should_add_line_to_output(input_line, prefix_set):
363                 output_lines.append(input_line)
364             m = MIR_BODY_BEGIN_RE.match(input_line)
365             if m:
366                 if func_name in simple_functions:
367                     # If there's only one block, put the checks inside it
368                     state = 'mir function prefix'
369                     continue
370                 state = 'mir function body'
371                 add_checks_for_function(test, output_lines, run_list,
372                                         func_dict, func_name, single_bb=False,
373                                         verbose=verbose)
374         elif state == 'mir function prefix':
375             m = MIR_PREFIX_DATA_RE.match(input_line)
376             if not m:
377                 state = 'mir function body'
378                 add_checks_for_function(test, output_lines, run_list,
379                                         func_dict, func_name, single_bb=True,
380                                         verbose=verbose)
381
382             if should_add_line_to_output(input_line, prefix_set):
383                 output_lines.append(input_line)
384         elif state == 'mir function body':
385             if input_line.strip() == '...':
386                 state = 'toplevel'
387                 func_name = None
388             if should_add_line_to_output(input_line, prefix_set):
389                 output_lines.append(input_line)
390         elif state == 'ir function prefix':
391             m = IR_PREFIX_DATA_RE.match(input_line)
392             if not m:
393                 state = 'ir function body'
394                 add_checks_for_function(test, output_lines, run_list,
395                                         func_dict, func_name, single_bb=False,
396                                         verbose=verbose)
397
398             if should_add_line_to_output(input_line, prefix_set):
399                 output_lines.append(input_line)
400         elif state == 'ir function body':
401             if input_line.strip() == '}':
402                 state = 'toplevel'
403                 func_name = None
404             if should_add_line_to_output(input_line, prefix_set):
405                 output_lines.append(input_line)
406
407
408     log('Writing {} lines to {}...'.format(len(output_lines), test), verbose)
409
410     with open(test, 'wb') as fd:
411         fd.writelines([l + '\n' for l in output_lines])
412
413
414 def main():
415     parser = argparse.ArgumentParser(
416         description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
417     parser.add_argument('-v', '--verbose', action='store_true',
418                         help='Show verbose output')
419     parser.add_argument('--llc-binary', dest='llc', default='llc', type=LLC,
420                         help='The "llc" binary to generate the test case with')
421     parser.add_argument('--remove-common-prefixes', action='store_true',
422                         help='Remove existing check lines whose prefixes are '
423                              'shared between multiple commands')
424     parser.add_argument('tests', nargs='+')
425     args = parser.parse_args()
426
427     for test in args.tests:
428         try:
429             update_test_file(args.llc, test, args.remove_common_prefixes,
430                              verbose=args.verbose)
431         except Exception:
432             warn('Error processing file', test_file=test)
433             raise
434
435
436 if __name__ == '__main__':
437   main()