From 9d559ba39dca49c30cdfc81e8fdfbefb06a05f2a Mon Sep 17 00:00:00 2001 From: Jason Ekstrand Date: Wed, 8 Mar 2017 20:23:05 -0800 Subject: [PATCH] nir/constant_expressions: Refactor helper functions Apart from avoiding some unneeded size cases, this shouldn't have any actual functional impact. Reviewed-by: Dylan Baker Reviewed-by: Lionel Landwerlin --- src/compiler/nir/nir_constant_expressions.py | 51 +++++++++++++++------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index 3da20fd503b..c6745f1e934 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -1,16 +1,18 @@ +import re + +type_split_re = re.compile(r'(?P[a-z]+)(?P\d+)') + def type_has_size(type_): return type_[-1:].isdigit() +def type_size(type_): + assert type_has_size(type_) + return int(type_split_re.match(type_).group('bits')) + def type_sizes(type_): - if type_.endswith("8"): - return [8] - elif type_.endswith("16"): - return [16] - elif type_.endswith("32"): - return [32] - elif type_.endswith("64"): - return [64] + if type_has_size(type_): + return [type_size(type_)] else: return [32, 64] @@ -19,23 +21,23 @@ def type_add_size(type_, size): return type_ return type_ + str(size) +def op_bit_sizes(op): + sizes = set([8, 16, 32, 64]) + if not type_has_size(op.output_type): + sizes = sizes.intersection(set(type_sizes(op.output_type))) + for input_type in op.input_types: + if not type_has_size(input_type): + sizes = sizes.intersection(set(type_sizes(input_type))) + return sorted(list(sizes)) + def get_const_field(type_): - if type_ == "int32": - return "i32" - if type_ == "uint32": - return "u32" - if type_ == "int64": - return "i64" - if type_ == "uint64": - return "u64" if type_ == "bool32": return "u32" - if type_ == "float32": - return "f32" - if type_ == "float64": - return "f64" - raise Exception(str(type_)) - assert(0) + else: + m = type_split_re.match(type_) + if not m: + raise Exception(str(type_)) + return m.group('type')[0] + m.group('bits') template = """\ /* @@ -247,7 +249,7 @@ typedef float float32_t; typedef double float64_t; typedef bool bool32_t; % for type in ["float", "int", "uint"]: -% for width in [32, 64]: +% for width in type_sizes(type): struct ${type}${width}_vec { ${type}${width}_t x; ${type}${width}_t y; @@ -272,7 +274,7 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size, nir_const_value _dst_val = { {0, } }; switch (bit_size) { - % for bit_size in [32, 64]: + % for bit_size in op_bit_sizes(op): case ${bit_size}: { <% output_type = type_add_size(op.output_type, bit_size) @@ -406,4 +408,5 @@ from mako.template import Template print Template(template).render(opcodes=opcodes, type_sizes=type_sizes, type_has_size=type_has_size, type_add_size=type_add_size, + op_bit_sizes=op_bit_sizes, get_const_field=get_const_field) -- 2.11.0