2 * Copyright © 2016 Intel Corporation
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
25 #include "vtn_private.h"
28 * Normally, column vectors in SPIR-V correspond to a single NIR SSA
29 * definition. But for matrix multiplies, we want to do one routine for
30 * multiplying a matrix by a matrix and then pretend that vectors are matrices
31 * with one column. So we "wrap" these things, and unwrap the result before we
35 static struct vtn_ssa_value *
36 wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
41 if (glsl_type_is_matrix(val->type))
44 struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
45 dest->type = val->type;
46 dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
52 static struct vtn_ssa_value *
53 unwrap_matrix(struct vtn_ssa_value *val)
55 if (glsl_type_is_matrix(val->type))
61 static struct vtn_ssa_value *
62 matrix_multiply(struct vtn_builder *b,
63 struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
66 struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
67 struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
68 struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
69 struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71 unsigned src0_rows = glsl_get_vector_elements(src0->type);
72 unsigned src0_columns = glsl_get_matrix_columns(src0->type);
73 unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75 const struct glsl_type *dest_type;
76 if (src1_columns > 1) {
77 dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
78 src0_rows, src1_columns);
80 dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84 dest = wrap_matrix(b, dest);
86 bool transpose_result = false;
87 if (src0_transpose && src1_transpose) {
88 /* transpose(A) * transpose(B) = transpose(B * A) */
89 src1 = src0_transpose;
90 src0 = src1_transpose;
91 src0_transpose = NULL;
92 src1_transpose = NULL;
93 transpose_result = true;
96 if (src0_transpose && !src1_transpose &&
97 glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
98 /* We already have the rows of src0 and the columns of src1 available,
99 * so we can just take the dot product of each row with each column to
103 for (unsigned i = 0; i < src1_columns; i++) {
104 nir_ssa_def *vec_src[4];
105 for (unsigned j = 0; j < src0_rows; j++) {
106 vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
107 src1->elems[i]->def);
109 dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
112 /* We don't handle the case where src1 is transposed but not src0, since
113 * the general case only uses individual components of src1 so the
114 * optimizer should chew through the transpose we emitted for src1.
117 for (unsigned i = 0; i < src1_columns; i++) {
118 /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
119 dest->elems[i]->def =
120 nir_fmul(&b->nb, src0->elems[0]->def,
121 nir_channel(&b->nb, src1->elems[i]->def, 0));
122 for (unsigned j = 1; j < src0_columns; j++) {
123 dest->elems[i]->def =
124 nir_fadd(&b->nb, dest->elems[i]->def,
125 nir_fmul(&b->nb, src0->elems[j]->def,
126 nir_channel(&b->nb, src1->elems[i]->def, j)));
131 dest = unwrap_matrix(dest);
133 if (transpose_result)
134 dest = vtn_ssa_transpose(b, dest);
139 static struct vtn_ssa_value *
140 mat_times_scalar(struct vtn_builder *b,
141 struct vtn_ssa_value *mat,
144 struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
145 for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
146 if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
147 dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
149 dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
156 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
157 struct vtn_value *dest,
158 struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
162 dest->ssa = vtn_create_ssa_value(b, src0->type);
163 unsigned cols = glsl_get_matrix_columns(src0->type);
164 for (unsigned i = 0; i < cols; i++)
165 dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
170 dest->ssa = vtn_create_ssa_value(b, src0->type);
171 unsigned cols = glsl_get_matrix_columns(src0->type);
172 for (unsigned i = 0; i < cols; i++)
173 dest->ssa->elems[i]->def =
174 nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
179 dest->ssa = vtn_create_ssa_value(b, src0->type);
180 unsigned cols = glsl_get_matrix_columns(src0->type);
181 for (unsigned i = 0; i < cols; i++)
182 dest->ssa->elems[i]->def =
183 nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
188 dest->ssa = vtn_ssa_transpose(b, src0);
191 case SpvOpMatrixTimesScalar:
192 if (src0->transposed) {
193 dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
196 dest->ssa = mat_times_scalar(b, src0, src1->def);
200 case SpvOpVectorTimesMatrix:
201 case SpvOpMatrixTimesVector:
202 case SpvOpMatrixTimesMatrix:
203 if (opcode == SpvOpVectorTimesMatrix) {
204 dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
206 dest->ssa = matrix_multiply(b, src0, src1);
210 default: vtn_fail("unknown matrix opcode");
215 vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
216 struct nir_ssa_def *src)
218 if (glsl_get_vector_elements(dest->type) == src->num_components) {
219 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
221 * "If Result Type has the same number of components as Operand, they
222 * must also have the same component width, and results are computed per
225 dest->def = nir_imov(&b->nb, src);
229 /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
231 * "If Result Type has a different number of components than Operand, the
232 * total number of bits in Result Type must equal the total number of bits
233 * in Operand. Let L be the type, either Result Type or Operand’s type, that
234 * has the larger number of components. Let S be the other type, with the
235 * smaller number of components. The number of components in L must be an
236 * integer multiple of the number of components in S. The first component
237 * (that is, the only or lowest-numbered component) of S maps to the first
238 * components of L, and so on, up to the last component of S mapping to the
239 * last components of L. Within this mapping, any single component of S
240 * (mapping to multiple components of L) maps its lower-ordered bits to the
241 * lower-numbered components of L."
243 unsigned src_bit_size = src->bit_size;
244 unsigned dest_bit_size = glsl_get_bit_size(dest->type);
245 unsigned src_components = src->num_components;
246 unsigned dest_components = glsl_get_vector_elements(dest->type);
247 vtn_assert(src_bit_size * src_components == dest_bit_size * dest_components);
249 nir_ssa_def *dest_chan[4];
250 if (src_bit_size > dest_bit_size) {
251 vtn_assert(src_bit_size % dest_bit_size == 0);
252 unsigned divisor = src_bit_size / dest_bit_size;
253 for (unsigned comp = 0; comp < src_components; comp++) {
254 vtn_assert(src_bit_size == 64);
255 vtn_assert(dest_bit_size == 32);
257 nir_unpack_64_2x32(&b->nb, nir_channel(&b->nb, src, comp));
258 for (unsigned i = 0; i < divisor; i++)
259 dest_chan[divisor * comp + i] = nir_channel(&b->nb, split, i);
262 vtn_assert(dest_bit_size % src_bit_size == 0);
263 unsigned divisor = dest_bit_size / src_bit_size;
264 for (unsigned comp = 0; comp < dest_components; comp++) {
265 unsigned channels = ((1 << divisor) - 1) << (comp * divisor);
266 nir_ssa_def *src_chan =
267 nir_channels(&b->nb, src, channels);
268 vtn_assert(dest_bit_size == 64);
269 vtn_assert(src_bit_size == 32);
270 dest_chan[comp] = nir_pack_64_2x32(&b->nb, src_chan);
273 dest->def = nir_vec(&b->nb, dest_chan, dest_components);
277 vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
278 SpvOp opcode, bool *swap,
279 unsigned src_bit_size, unsigned dst_bit_size)
281 /* Indicates that the first two arguments should be swapped. This is
282 * used for implementing greater-than and less-than-or-equal.
287 case SpvOpSNegate: return nir_op_ineg;
288 case SpvOpFNegate: return nir_op_fneg;
289 case SpvOpNot: return nir_op_inot;
290 case SpvOpIAdd: return nir_op_iadd;
291 case SpvOpFAdd: return nir_op_fadd;
292 case SpvOpISub: return nir_op_isub;
293 case SpvOpFSub: return nir_op_fsub;
294 case SpvOpIMul: return nir_op_imul;
295 case SpvOpFMul: return nir_op_fmul;
296 case SpvOpUDiv: return nir_op_udiv;
297 case SpvOpSDiv: return nir_op_idiv;
298 case SpvOpFDiv: return nir_op_fdiv;
299 case SpvOpUMod: return nir_op_umod;
300 case SpvOpSMod: return nir_op_imod;
301 case SpvOpFMod: return nir_op_fmod;
302 case SpvOpSRem: return nir_op_irem;
303 case SpvOpFRem: return nir_op_frem;
305 case SpvOpShiftRightLogical: return nir_op_ushr;
306 case SpvOpShiftRightArithmetic: return nir_op_ishr;
307 case SpvOpShiftLeftLogical: return nir_op_ishl;
308 case SpvOpLogicalOr: return nir_op_ior;
309 case SpvOpLogicalEqual: return nir_op_ieq;
310 case SpvOpLogicalNotEqual: return nir_op_ine;
311 case SpvOpLogicalAnd: return nir_op_iand;
312 case SpvOpLogicalNot: return nir_op_inot;
313 case SpvOpBitwiseOr: return nir_op_ior;
314 case SpvOpBitwiseXor: return nir_op_ixor;
315 case SpvOpBitwiseAnd: return nir_op_iand;
316 case SpvOpSelect: return nir_op_bcsel;
317 case SpvOpIEqual: return nir_op_ieq;
319 case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
320 case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
321 case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
322 case SpvOpBitReverse: return nir_op_bitfield_reverse;
323 case SpvOpBitCount: return nir_op_bit_count;
325 /* The ordered / unordered operators need special implementation besides
326 * the logical operator to use since they also need to check if operands are
329 case SpvOpFOrdEqual: return nir_op_feq;
330 case SpvOpFUnordEqual: return nir_op_feq;
331 case SpvOpINotEqual: return nir_op_ine;
332 case SpvOpFOrdNotEqual: return nir_op_fne;
333 case SpvOpFUnordNotEqual: return nir_op_fne;
334 case SpvOpULessThan: return nir_op_ult;
335 case SpvOpSLessThan: return nir_op_ilt;
336 case SpvOpFOrdLessThan: return nir_op_flt;
337 case SpvOpFUnordLessThan: return nir_op_flt;
338 case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
339 case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
340 case SpvOpFOrdGreaterThan: *swap = true; return nir_op_flt;
341 case SpvOpFUnordGreaterThan: *swap = true; return nir_op_flt;
342 case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
343 case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
344 case SpvOpFOrdLessThanEqual: *swap = true; return nir_op_fge;
345 case SpvOpFUnordLessThanEqual: *swap = true; return nir_op_fge;
346 case SpvOpUGreaterThanEqual: return nir_op_uge;
347 case SpvOpSGreaterThanEqual: return nir_op_ige;
348 case SpvOpFOrdGreaterThanEqual: return nir_op_fge;
349 case SpvOpFUnordGreaterThanEqual: return nir_op_fge;
352 case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
354 case SpvOpConvertFToU:
355 case SpvOpConvertFToS:
356 case SpvOpConvertSToF:
357 case SpvOpConvertUToF:
359 case SpvOpFConvert: {
360 nir_alu_type src_type;
361 nir_alu_type dst_type;
364 case SpvOpConvertFToS:
365 src_type = nir_type_float;
366 dst_type = nir_type_int;
368 case SpvOpConvertFToU:
369 src_type = nir_type_float;
370 dst_type = nir_type_uint;
373 src_type = dst_type = nir_type_float;
375 case SpvOpConvertSToF:
376 src_type = nir_type_int;
377 dst_type = nir_type_float;
380 src_type = dst_type = nir_type_int;
382 case SpvOpConvertUToF:
383 src_type = nir_type_uint;
384 dst_type = nir_type_float;
387 src_type = dst_type = nir_type_uint;
390 unreachable("Invalid opcode");
392 src_type |= src_bit_size;
393 dst_type |= dst_bit_size;
394 return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
397 case SpvOpDPdx: return nir_op_fddx;
398 case SpvOpDPdy: return nir_op_fddy;
399 case SpvOpDPdxFine: return nir_op_fddx_fine;
400 case SpvOpDPdyFine: return nir_op_fddy_fine;
401 case SpvOpDPdxCoarse: return nir_op_fddx_coarse;
402 case SpvOpDPdyCoarse: return nir_op_fddy_coarse;
405 vtn_fail("No NIR equivalent");
/*
 * Decoration callback, invoked through vtn_foreach_decoration() on the result
 * value of an ALU instruction; that call site passes NULL as the data pointer,
 * and _void is not referenced in the lines shown here.  The callback asserts
 * (vtn_assert) that the decoration has VTN_DEC_DECORATION scope and then tests
 * whether it is SpvDecorationNoContraction.
 *
 * NOTE(review): this excerpt is missing the return-type line and every
 * statement after the NoContraction test, so what the callback does for a
 * NoContraction decoration (presumably marks the NIR emitted for this
 * instruction as exact / not fusable) cannot be confirmed here — verify
 * against the full file.
 */
410 handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
411 const struct vtn_decoration *dec, void *_void)
413 vtn_assert(dec->scope == VTN_DEC_DECORATION);
414 if (dec->decoration != SpvDecorationNoContraction)
421 handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
422 const struct vtn_decoration *dec, void *_out_rounding_mode)
424 nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
425 assert(dec->scope == VTN_DEC_DECORATION);
426 if (dec->decoration != SpvDecorationFPRoundingMode)
428 switch (dec->literals[0]) {
429 case SpvFPRoundingModeRTE:
430 *out_rounding_mode = nir_rounding_mode_rtne;
432 case SpvFPRoundingModeRTZ:
433 *out_rounding_mode = nir_rounding_mode_rtz;
436 unreachable("Not supported rounding mode");
442 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
443 const uint32_t *w, unsigned count)
445 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
446 const struct glsl_type *type =
447 vtn_value(b, w[1], vtn_value_type_type)->type->type;
449 vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
451 /* Collect the various SSA sources */
452 const unsigned num_inputs = count - 3;
453 struct vtn_ssa_value *vtn_src[4] = { NULL, };
454 for (unsigned i = 0; i < num_inputs; i++)
455 vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
457 if (glsl_type_is_matrix(vtn_src[0]->type) ||
458 (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
459 vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
464 val->ssa = vtn_create_ssa_value(b, type);
465 nir_ssa_def *src[4] = { NULL, };
466 for (unsigned i = 0; i < num_inputs; i++) {
467 vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
468 src[i] = vtn_src[i]->def;
473 if (src[0]->num_components == 1) {
474 val->ssa->def = nir_imov(&b->nb, src[0]);
477 switch (src[0]->num_components) {
478 case 2: op = nir_op_bany_inequal2; break;
479 case 3: op = nir_op_bany_inequal3; break;
480 case 4: op = nir_op_bany_inequal4; break;
481 default: vtn_fail("invalid number of components");
483 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
484 nir_imm_int(&b->nb, NIR_FALSE),
490 if (src[0]->num_components == 1) {
491 val->ssa->def = nir_imov(&b->nb, src[0]);
494 switch (src[0]->num_components) {
495 case 2: op = nir_op_ball_iequal2; break;
496 case 3: op = nir_op_ball_iequal3; break;
497 case 4: op = nir_op_ball_iequal4; break;
498 default: vtn_fail("invalid number of components");
500 val->ssa->def = nir_build_alu(&b->nb, op, src[0],
501 nir_imm_int(&b->nb, NIR_TRUE),
506 case SpvOpOuterProduct: {
507 for (unsigned i = 0; i < src[1]->num_components; i++) {
508 val->ssa->elems[i]->def =
509 nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
515 val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
519 vtn_assert(glsl_type_is_struct(val->ssa->type));
520 val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
521 val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
524 case SpvOpISubBorrow:
525 vtn_assert(glsl_type_is_struct(val->ssa->type));
526 val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
527 val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
530 case SpvOpUMulExtended:
531 vtn_assert(glsl_type_is_struct(val->ssa->type));
532 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
533 val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
536 case SpvOpSMulExtended:
537 vtn_assert(glsl_type_is_struct(val->ssa->type));
538 val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
539 val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
543 val->ssa->def = nir_fadd(&b->nb,
544 nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
545 nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
547 case SpvOpFwidthFine:
548 val->ssa->def = nir_fadd(&b->nb,
549 nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
550 nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
552 case SpvOpFwidthCoarse:
553 val->ssa->def = nir_fadd(&b->nb,
554 nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
555 nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
558 case SpvOpVectorTimesScalar:
559 /* The builder will take care of splatting for us. */
560 val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
564 val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
568 nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
569 val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
573 case SpvOpFUnordEqual:
574 case SpvOpFUnordNotEqual:
575 case SpvOpFUnordLessThan:
576 case SpvOpFUnordGreaterThan:
577 case SpvOpFUnordLessThanEqual:
578 case SpvOpFUnordGreaterThanEqual: {
580 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
581 unsigned dst_bit_size = glsl_get_bit_size(type);
582 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
583 src_bit_size, dst_bit_size);
586 nir_ssa_def *tmp = src[0];
593 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
595 nir_fne(&b->nb, src[0], src[0]),
596 nir_fne(&b->nb, src[1], src[1])));
600 case SpvOpFOrdNotEqual: {
601 /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
602 * from the ALU will probably already be false if the operands are not
603 * ordered so we don’t need to handle it specially.
606 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
607 unsigned dst_bit_size = glsl_get_bit_size(type);
608 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
609 src_bit_size, dst_bit_size);
615 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
617 nir_feq(&b->nb, src[0], src[0]),
618 nir_feq(&b->nb, src[1], src[1])));
623 vtn_handle_bitcast(b, val->ssa, src[0]);
626 case SpvOpFConvert: {
627 nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
628 nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
629 nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
631 vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
632 nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
634 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
638 case SpvOpBitFieldInsert:
639 case SpvOpBitFieldSExtract:
640 case SpvOpBitFieldUExtract:
641 case SpvOpShiftLeftLogical:
642 case SpvOpShiftRightArithmetic:
643 case SpvOpShiftRightLogical: {
645 unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
646 unsigned dst_bit_size = glsl_get_bit_size(type);
647 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
648 src0_bit_size, dst_bit_size);
650 assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
651 op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
652 op == nir_op_ibitfield_extract);
654 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
655 unsigned src_bit_size =
656 nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
657 if (src_bit_size == 0)
659 if (src_bit_size != src[i]->bit_size) {
660 assert(src_bit_size == 32);
661 /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
662 * supported by the NIR instructions. See discussion here:
664 * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
666 src[i] = nir_u2u32(&b->nb, src[i]);
669 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
675 unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
676 unsigned dst_bit_size = glsl_get_bit_size(type);
677 nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
678 src_bit_size, dst_bit_size);
681 nir_ssa_def *tmp = src[0];
686 val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);