OSDN Git Service

spirv: convert some operands for bitwise shift and bitwise ops to uint32
[android-x86/external-mesa.git] / src / compiler / spirv / vtn_alu.c
1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23
24 #include <math.h>
25 #include "vtn_private.h"
26
27 /*
28  * Normally, column vectors in SPIR-V correspond to a single NIR SSA
29  * definition. But for matrix multiplies, we want to do one routine for
30  * multiplying a matrix by a matrix and then pretend that vectors are matrices
31  * with one column. So we "wrap" these things, and unwrap the result before we
32  * send it off.
33  */
34
35 static struct vtn_ssa_value *
36 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
37 {
38    if (val == NULL)
39       return NULL;
40
41    if (glsl_type_is_matrix(val->type))
42       return val;
43
44    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
45    dest->type = val->type;
46    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
47    dest->elems[0] = val;
48
49    return dest;
50 }
51
52 static struct vtn_ssa_value *
53 unwrap_matrix(struct vtn_ssa_value *val)
54 {
55    if (glsl_type_is_matrix(val->type))
56          return val;
57
58    return val->elems[0];
59 }
60
61 static struct vtn_ssa_value *
62 matrix_multiply(struct vtn_builder *b,
63                 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
64 {
65
66    struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
67    struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
68    struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
69    struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
70
71    unsigned src0_rows = glsl_get_vector_elements(src0->type);
72    unsigned src0_columns = glsl_get_matrix_columns(src0->type);
73    unsigned src1_columns = glsl_get_matrix_columns(src1->type);
74
75    const struct glsl_type *dest_type;
76    if (src1_columns > 1) {
77       dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
78                                    src0_rows, src1_columns);
79    } else {
80       dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
81    }
82    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
83
84    dest = wrap_matrix(b, dest);
85
86    bool transpose_result = false;
87    if (src0_transpose && src1_transpose) {
88       /* transpose(A) * transpose(B) = transpose(B * A) */
89       src1 = src0_transpose;
90       src0 = src1_transpose;
91       src0_transpose = NULL;
92       src1_transpose = NULL;
93       transpose_result = true;
94    }
95
96    if (src0_transpose && !src1_transpose &&
97        glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
98       /* We already have the rows of src0 and the columns of src1 available,
99        * so we can just take the dot product of each row with each column to
100        * get the result.
101        */
102
103       for (unsigned i = 0; i < src1_columns; i++) {
104          nir_ssa_def *vec_src[4];
105          for (unsigned j = 0; j < src0_rows; j++) {
106             vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
107                                           src1->elems[i]->def);
108          }
109          dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
110       }
111    } else {
112       /* We don't handle the case where src1 is transposed but not src0, since
113        * the general case only uses individual components of src1 so the
114        * optimizer should chew through the transpose we emitted for src1.
115        */
116
117       for (unsigned i = 0; i < src1_columns; i++) {
118          /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
119          dest->elems[i]->def =
120             nir_fmul(&b->nb, src0->elems[0]->def,
121                      nir_channel(&b->nb, src1->elems[i]->def, 0));
122          for (unsigned j = 1; j < src0_columns; j++) {
123             dest->elems[i]->def =
124                nir_fadd(&b->nb, dest->elems[i]->def,
125                         nir_fmul(&b->nb, src0->elems[j]->def,
126                                  nir_channel(&b->nb, src1->elems[i]->def, j)));
127          }
128       }
129    }
130
131    dest = unwrap_matrix(dest);
132
133    if (transpose_result)
134       dest = vtn_ssa_transpose(b, dest);
135
136    return dest;
137 }
138
139 static struct vtn_ssa_value *
140 mat_times_scalar(struct vtn_builder *b,
141                  struct vtn_ssa_value *mat,
142                  nir_ssa_def *scalar)
143 {
144    struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
145    for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
146       if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
147          dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
148       else
149          dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
150    }
151
152    return dest;
153 }
154
155 static void
156 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
157                       struct vtn_value *dest,
158                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
159 {
160    switch (opcode) {
161    case SpvOpFNegate: {
162       dest->ssa = vtn_create_ssa_value(b, src0->type);
163       unsigned cols = glsl_get_matrix_columns(src0->type);
164       for (unsigned i = 0; i < cols; i++)
165          dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
166       break;
167    }
168
169    case SpvOpFAdd: {
170       dest->ssa = vtn_create_ssa_value(b, src0->type);
171       unsigned cols = glsl_get_matrix_columns(src0->type);
172       for (unsigned i = 0; i < cols; i++)
173          dest->ssa->elems[i]->def =
174             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
175       break;
176    }
177
178    case SpvOpFSub: {
179       dest->ssa = vtn_create_ssa_value(b, src0->type);
180       unsigned cols = glsl_get_matrix_columns(src0->type);
181       for (unsigned i = 0; i < cols; i++)
182          dest->ssa->elems[i]->def =
183             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
184       break;
185    }
186
187    case SpvOpTranspose:
188       dest->ssa = vtn_ssa_transpose(b, src0);
189       break;
190
191    case SpvOpMatrixTimesScalar:
192       if (src0->transposed) {
193          dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
194                                                            src1->def));
195       } else {
196          dest->ssa = mat_times_scalar(b, src0, src1->def);
197       }
198       break;
199
200    case SpvOpVectorTimesMatrix:
201    case SpvOpMatrixTimesVector:
202    case SpvOpMatrixTimesMatrix:
203       if (opcode == SpvOpVectorTimesMatrix) {
204          dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
205       } else {
206          dest->ssa = matrix_multiply(b, src0, src1);
207       }
208       break;
209
210    default: vtn_fail("unknown matrix opcode");
211    }
212 }
213
214 static void
215 vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
216                    struct nir_ssa_def *src)
217 {
218    if (glsl_get_vector_elements(dest->type) == src->num_components) {
219       /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
220        *
221        * "If Result Type has the same number of components as Operand, they
222        * must also have the same component width, and results are computed per
223        * component."
224        */
225       dest->def = nir_imov(&b->nb, src);
226       return;
227    }
228
229    /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
230     *
231     * "If Result Type has a different number of components than Operand, the
232     * total number of bits in Result Type must equal the total number of bits
233     * in Operand. Let L be the type, either Result Type or Operand’s type, that
234     * has the larger number of components. Let S be the other type, with the
235     * smaller number of components. The number of components in L must be an
236     * integer multiple of the number of components in S. The first component
237     * (that is, the only or lowest-numbered component) of S maps to the first
238     * components of L, and so on, up to the last component of S mapping to the
239     * last components of L. Within this mapping, any single component of S
240     * (mapping to multiple components of L) maps its lower-ordered bits to the
241     * lower-numbered components of L."
242     */
243    unsigned src_bit_size = src->bit_size;
244    unsigned dest_bit_size = glsl_get_bit_size(dest->type);
245    unsigned src_components = src->num_components;
246    unsigned dest_components = glsl_get_vector_elements(dest->type);
247    vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
248
249    nir_ssa_def *dest_chan[4];
250    if (src_bit_size > dest_bit_size) {
251       vtn_assert(src_bit_size % dest_bit_size == 0);
252       unsigned divisor = src_bit_size / dest_bit_size;
253       for (unsigned comp = 0; comp < src_components; comp++) {
254          vtn_assert(src_bit_size == 64);
255          vtn_assert(dest_bit_size == 32);
256          nir_ssa_def *split =
257             nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp));
258          for (unsigned i = 0; i < divisor; i++)
259             dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
260       }
261    } else {
262       vtn_assert(dest_bit_size % src_bit_size == 0);
263       unsigned divisor = dest_bit_size / src_bit_size;
264       for (unsigned comp = 0; comp < dest_components; comp++) {
265          unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
266          nir_ssa_def *src_chan =
267             nir_channels(&b->nb, src, channels);
268          vtn_assert(dest_bit_size == 64);
269          vtn_assert(src_bit_size == 32);
270          dest_chan[comp] = nir_pack_64_2x32(&b->nb, src_chan);
271       }
272    }
273    dest->def = nir_vec(&b->nb, dest_chan, dest_components);
274 }
275
276 nir_op
277 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
278                                 SpvOp opcode, bool *swap,
279                                 unsigned src_bit_size, unsigned dst_bit_size)
280 {
281    /* Indicates that the first two arguments should be swapped.  This is
282     * used for implementing greater-than and less-than-or-equal.
283     */
284    *swap = false;
285
286    switch (opcode) {
287    case SpvOpSNegate:            return nir_op_ineg;
288    case SpvOpFNegate:            return nir_op_fneg;
289    case SpvOpNot:                return nir_op_inot;
290    case SpvOpIAdd:               return nir_op_iadd;
291    case SpvOpFAdd:               return nir_op_fadd;
292    case SpvOpISub:               return nir_op_isub;
293    case SpvOpFSub:               return nir_op_fsub;
294    case SpvOpIMul:               return nir_op_imul;
295    case SpvOpFMul:               return nir_op_fmul;
296    case SpvOpUDiv:               return nir_op_udiv;
297    case SpvOpSDiv:               return nir_op_idiv;
298    case SpvOpFDiv:               return nir_op_fdiv;
299    case SpvOpUMod:               return nir_op_umod;
300    case SpvOpSMod:               return nir_op_imod;
301    case SpvOpFMod:               return nir_op_fmod;
302    case SpvOpSRem:               return nir_op_irem;
303    case SpvOpFRem:               return nir_op_frem;
304
305    case SpvOpShiftRightLogical:     return nir_op_ushr;
306    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
307    case SpvOpShiftLeftLogical:      return nir_op_ishl;
308    case SpvOpLogicalOr:             return nir_op_ior;
309    case SpvOpLogicalEqual:          return nir_op_ieq;
310    case SpvOpLogicalNotEqual:       return nir_op_ine;
311    case SpvOpLogicalAnd:            return nir_op_iand;
312    case SpvOpLogicalNot:            return nir_op_inot;
313    case SpvOpBitwiseOr:             return nir_op_ior;
314    case SpvOpBitwiseXor:            return nir_op_ixor;
315    case SpvOpBitwiseAnd:            return nir_op_iand;
316    case SpvOpSelect:                return nir_op_bcsel;
317    case SpvOpIEqual:                return nir_op_ieq;
318
319    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
320    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
321    case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
322    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
323    case SpvOpBitCount:              return nir_op_bit_count;
324
325    /* The ordered / unordered operators need special implementation besides
326     * the logical operator to use since they also need to check if operands are
327     * ordered.
328     */
329    case SpvOpFOrdEqual:                            return nir_op_feq;
330    case SpvOpFUnordEqual:                          return nir_op_feq;
331    case SpvOpINotEqual:                            return nir_op_ine;
332    case SpvOpFOrdNotEqual:                         return nir_op_fne;
333    case SpvOpFUnordNotEqual:                       return nir_op_fne;
334    case SpvOpULessThan:                            return nir_op_ult;
335    case SpvOpSLessThan:                            return nir_op_ilt;
336    case SpvOpFOrdLessThan:                         return nir_op_flt;
337    case SpvOpFUnordLessThan:                       return nir_op_flt;
338    case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
339    case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
340    case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
341    case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
342    case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
343    case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
344    case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
345    case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
346    case SpvOpUGreaterThanEqual:                    return nir_op_uge;
347    case SpvOpSGreaterThanEqual:                    return nir_op_ige;
348    case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
349    case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
350
351    /* Conversions: */
352    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
353    case SpvOpUConvert:
354    case SpvOpConvertFToU:
355    case SpvOpConvertFToS:
356    case SpvOpConvertSToF:
357    case SpvOpConvertUToF:
358    case SpvOpSConvert:
359    case SpvOpFConvert: {
360       nir_alu_type src_type;
361       nir_alu_type dst_type;
362
363       switch (opcode) {
364       case SpvOpConvertFToS:
365          src_type = nir_type_float;
366          dst_type = nir_type_int;
367          break;
368       case SpvOpConvertFToU:
369          src_type = nir_type_float;
370          dst_type = nir_type_uint;
371          break;
372       case SpvOpFConvert:
373          src_type = dst_type = nir_type_float;
374          break;
375       case SpvOpConvertSToF:
376          src_type = nir_type_int;
377          dst_type = nir_type_float;
378          break;
379       case SpvOpSConvert:
380          src_type = dst_type = nir_type_int;
381          break;
382       case SpvOpConvertUToF:
383          src_type = nir_type_uint;
384          dst_type = nir_type_float;
385          break;
386       case SpvOpUConvert:
387          src_type = dst_type = nir_type_uint;
388          break;
389       default:
390          unreachable("Invalid opcode");
391       }
392       src_type |= src_bit_size;
393       dst_type |= dst_bit_size;
394       return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
395    }
396    /* Derivatives: */
397    case SpvOpDPdx:         return nir_op_fddx;
398    case SpvOpDPdy:         return nir_op_fddy;
399    case SpvOpDPdxFine:     return nir_op_fddx_fine;
400    case SpvOpDPdyFine:     return nir_op_fddy_fine;
401    case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
402    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
403
404    default:
405       vtn_fail("No NIR equivalent");
406    }
407 }
408
409 static void
410 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
411                       const struct vtn_decoration *dec, void *_void)
412 {
413    vtn_assert(dec->scope == VTN_DEC_DECORATION);
414    if (dec->decoration != SpvDecorationNoContraction)
415       return;
416
417    b->nb.exact = true;
418 }
419
420 static void
421 handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
422                      const struct vtn_decoration *dec, void *_out_rounding_mode)
423 {
424    nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
425    assert(dec->scope == VTN_DEC_DECORATION);
426    if (dec->decoration != SpvDecorationFPRoundingMode)
427       return;
428    switch (dec->literals[0]) {
429    case SpvFPRoundingModeRTE:
430       *out_rounding_mode = nir_rounding_mode_rtne;
431       break;
432    case SpvFPRoundingModeRTZ:
433       *out_rounding_mode = nir_rounding_mode_rtz;
434       break;
435    default:
436       unreachable("Not supported rounding mode");
437       break;
438    }
439 }
440
441 void
442 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
443                const uint32_t *w, unsigned count)
444 {
445    struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
446    const struct glsl_type *type =
447       vtn_value(b, w[1], vtn_value_type_type)->type->type;
448
449    vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
450
451    /* Collect the various SSA sources */
452    const unsigned num_inputs = count - 3;
453    struct vtn_ssa_value *vtn_src[4] = { NULL, };
454    for (unsigned i = 0; i < num_inputs; i++)
455       vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
456
457    if (glsl_type_is_matrix(vtn_src[0]->type) ||
458        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
459       vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
460       b->nb.exact = false;
461       return;
462    }
463
464    val->ssa = vtn_create_ssa_value(b, type);
465    nir_ssa_def *src[4] = { NULL, };
466    for (unsigned i = 0; i < num_inputs; i++) {
467       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
468       src[i] = vtn_src[i]->def;
469    }
470
471    switch (opcode) {
472    case SpvOpAny:
473       if (src[0]->num_components == 1) {
474          val->ssa->def = nir_imov(&b->nb, src[0]);
475       } else {
476          nir_op op;
477          switch (src[0]->num_components) {
478          case 2:  op = nir_op_bany_inequal2; break;
479          case 3:  op = nir_op_bany_inequal3; break;
480          case 4:  op = nir_op_bany_inequal4; break;
481          default: vtn_fail("invalid number of components");
482          }
483          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
484                                        nir_imm_int(&b->nb, NIR_FALSE),
485                                        NULL, NULL);
486       }
487       break;
488
489    case SpvOpAll:
490       if (src[0]->num_components == 1) {
491          val->ssa->def = nir_imov(&b->nb, src[0]);
492       } else {
493          nir_op op;
494          switch (src[0]->num_components) {
495          case 2:  op = nir_op_ball_iequal2;  break;
496          case 3:  op = nir_op_ball_iequal3;  break;
497          case 4:  op = nir_op_ball_iequal4;  break;
498          default: vtn_fail("invalid number of components");
499          }
500          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
501                                        nir_imm_int(&b->nb, NIR_TRUE),
502                                        NULL, NULL);
503       }
504       break;
505
506    case SpvOpOuterProduct: {
507       for (unsigned i = 0; i < src[1]->num_components; i++) {
508          val->ssa->elems[i]->def =
509             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
510       }
511       break;
512    }
513
514    case SpvOpDot:
515       val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
516       break;
517
518    case SpvOpIAddCarry:
519       vtn_assert(glsl_type_is_struct(val->ssa->type));
520       val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
521       val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
522       break;
523
524    case SpvOpISubBorrow:
525       vtn_assert(glsl_type_is_struct(val->ssa->type));
526       val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
527       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
528       break;
529
530    case SpvOpUMulExtended:
531       vtn_assert(glsl_type_is_struct(val->ssa->type));
532       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
533       val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
534       break;
535
536    case SpvOpSMulExtended:
537       vtn_assert(glsl_type_is_struct(val->ssa->type));
538       val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
539       val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
540       break;
541
542    case SpvOpFwidth:
543       val->ssa->def = nir_fadd(&b->nb,
544                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
545                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
546       break;
547    case SpvOpFwidthFine:
548       val->ssa->def = nir_fadd(&b->nb,
549                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
550                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
551       break;
552    case SpvOpFwidthCoarse:
553       val->ssa->def = nir_fadd(&b->nb,
554                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
555                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
556       break;
557
558    case SpvOpVectorTimesScalar:
559       /* The builder will take care of splatting for us. */
560       val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
561       break;
562
563    case SpvOpIsNan:
564       val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
565       break;
566
567    case SpvOpIsInf: {
568       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
569       val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
570       break;
571    }
572
573    case SpvOpFUnordEqual:
574    case SpvOpFUnordNotEqual:
575    case SpvOpFUnordLessThan:
576    case SpvOpFUnordGreaterThan:
577    case SpvOpFUnordLessThanEqual:
578    case SpvOpFUnordGreaterThanEqual: {
579       bool swap;
580       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
581       unsigned dst_bit_size = glsl_get_bit_size(type);
582       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
583                                                   src_bit_size, dst_bit_size);
584
585       if (swap) {
586          nir_ssa_def *tmp = src[0];
587          src[0] = src[1];
588          src[1] = tmp;
589       }
590
591       val->ssa->def =
592          nir_ior(&b->nb,
593                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
594                  nir_ior(&b->nb,
595                          nir_fne(&b->nb, src[0], src[0]),
596                          nir_fne(&b->nb, src[1], src[1])));
597       break;
598    }
599
600    case SpvOpFOrdNotEqual: {
601       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
602        * from the ALU will probably already be false if the operands are not
603        * ordered so we don’t need to handle it specially.
604        */
605       bool swap;
606       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
607       unsigned dst_bit_size = glsl_get_bit_size(type);
608       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
609                                                   src_bit_size, dst_bit_size);
610
611       assert(!swap);
612
613       val->ssa->def =
614          nir_iand(&b->nb,
615                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
616                   nir_iand(&b->nb,
617                           nir_feq(&b->nb, src[0], src[0]),
618                           nir_feq(&b->nb, src[1], src[1])));
619       break;
620    }
621
622    case SpvOpBitcast:
623       vtn_handle_bitcast(b, val->ssa, src[0]);
624       break;
625
626    case SpvOpFConvert: {
627       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
628       nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
629       nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
630
631       vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
632       nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
633
634       val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
635       break;
636    }
637
638    case SpvOpBitFieldInsert:
639    case SpvOpBitFieldSExtract:
640    case SpvOpBitFieldUExtract:
641    case SpvOpShiftLeftLogical:
642    case SpvOpShiftRightArithmetic:
643    case SpvOpShiftRightLogical: {
644       bool swap;
645       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
646       unsigned dst_bit_size = glsl_get_bit_size(type);
647       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
648                                                   src0_bit_size, dst_bit_size);
649
650       assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
651               op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
652               op == nir_op_ibitfield_extract);
653
654       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
655          unsigned src_bit_size =
656             nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
657          if (src_bit_size == 0)
658             continue;
659          if (src_bit_size != src[i]->bit_size) {
660             assert(src_bit_size == 32);
661             /* Convert the Shift, Offset and Count  operands to 32 bits, which is the bitsize
662              * supported by the NIR instructions. See discussion here:
663              *
664              * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
665              */
666             src[i] = nir_u2u32(&b->nb, src[i]);
667          }
668       }
669       val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
670       break;
671    }
672
673    default: {
674       bool swap;
675       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
676       unsigned dst_bit_size = glsl_get_bit_size(type);
677       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
678                                                   src_bit_size, dst_bit_size);
679
680       if (swap) {
681          nir_ssa_def *tmp = src[0];
682          src[0] = src[1];
683          src[1] = tmp;
684       }
685
686       val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
687       break;
688    } /* default */
689    }
690
691    b->nb.exact = false;
692 }