OSDN Git Service

libvhost-user: return on error in vu_log_queue_fill()
[qmiga/qemu.git] / contrib / libvhost-user / libvhost-user.c
1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <aliguori@us.ibm.com>
9  *  Marc-AndrĂ© Lureau <mlureau@redhat.com>
10  *  Victor Kaplansky <victork@redhat.com>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15
16 /* this code avoids GLib dependency */
17 #include <stdlib.h>
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <stdarg.h>
21 #include <errno.h>
22 #include <string.h>
23 #include <assert.h>
24 #include <inttypes.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/eventfd.h>
28 #include <sys/mman.h>
29 #include "qemu/compiler.h"
30
31 #if defined(__linux__)
32 #include <sys/syscall.h>
33 #include <fcntl.h>
34 #include <sys/ioctl.h>
35 #include <linux/vhost.h>
36
37 #ifdef __NR_userfaultfd
38 #include <linux/userfaultfd.h>
39 #endif
40
41 #endif
42
43 #include "qemu/atomic.h"
44 #include "qemu/osdep.h"
45 #include "qemu/bswap.h"
46 #include "qemu/memfd.h"
47
48 #include "libvhost-user.h"
49
50 /* usually provided by GLib */
51 #ifndef MIN
52 #define MIN(x, y) ({                            \
53             typeof(x) _min1 = (x);              \
54             typeof(y) _min2 = (y);              \
55             (void) (&_min1 == &_min2);          \
56             _min1 < _min2 ? _min1 : _min2; })
57 #endif
58
59 /* Round number down to multiple */
60 #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
61
62 /* Round number up to multiple */
63 #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
64
65 /* Align each region to cache line size in inflight buffer */
66 #define INFLIGHT_ALIGNMENT 64
67
68 /* The version of inflight buffer */
69 #define INFLIGHT_VERSION 1
70
71 #define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
72
73 /* The version of the protocol we support */
74 #define VHOST_USER_VERSION 1
75 #define LIBVHOST_USER_DEBUG 0
76
77 #define DPRINT(...)                             \
78     do {                                        \
79         if (LIBVHOST_USER_DEBUG) {              \
80             fprintf(stderr, __VA_ARGS__);        \
81         }                                       \
82     } while (0)
83
84 static inline
85 bool has_feature(uint64_t features, unsigned int fbit)
86 {
87     assert(fbit < 64);
88     return !!(features & (1ULL << fbit));
89 }
90
91 static inline
92 bool vu_has_feature(VuDev *dev,
93                     unsigned int fbit)
94 {
95     return has_feature(dev->features, fbit);
96 }
97
98 static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
99 {
100     return has_feature(dev->protocol_features, fbit);
101 }
102
103 static const char *
104 vu_request_to_string(unsigned int req)
105 {
106 #define REQ(req) [req] = #req
107     static const char *vu_request_str[] = {
108         REQ(VHOST_USER_NONE),
109         REQ(VHOST_USER_GET_FEATURES),
110         REQ(VHOST_USER_SET_FEATURES),
111         REQ(VHOST_USER_SET_OWNER),
112         REQ(VHOST_USER_RESET_OWNER),
113         REQ(VHOST_USER_SET_MEM_TABLE),
114         REQ(VHOST_USER_SET_LOG_BASE),
115         REQ(VHOST_USER_SET_LOG_FD),
116         REQ(VHOST_USER_SET_VRING_NUM),
117         REQ(VHOST_USER_SET_VRING_ADDR),
118         REQ(VHOST_USER_SET_VRING_BASE),
119         REQ(VHOST_USER_GET_VRING_BASE),
120         REQ(VHOST_USER_SET_VRING_KICK),
121         REQ(VHOST_USER_SET_VRING_CALL),
122         REQ(VHOST_USER_SET_VRING_ERR),
123         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
124         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
125         REQ(VHOST_USER_GET_QUEUE_NUM),
126         REQ(VHOST_USER_SET_VRING_ENABLE),
127         REQ(VHOST_USER_SEND_RARP),
128         REQ(VHOST_USER_NET_SET_MTU),
129         REQ(VHOST_USER_SET_SLAVE_REQ_FD),
130         REQ(VHOST_USER_IOTLB_MSG),
131         REQ(VHOST_USER_SET_VRING_ENDIAN),
132         REQ(VHOST_USER_GET_CONFIG),
133         REQ(VHOST_USER_SET_CONFIG),
134         REQ(VHOST_USER_POSTCOPY_ADVISE),
135         REQ(VHOST_USER_POSTCOPY_LISTEN),
136         REQ(VHOST_USER_POSTCOPY_END),
137         REQ(VHOST_USER_GET_INFLIGHT_FD),
138         REQ(VHOST_USER_SET_INFLIGHT_FD),
139         REQ(VHOST_USER_GPU_SET_SOCKET),
140         REQ(VHOST_USER_VRING_KICK),
141         REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
142         REQ(VHOST_USER_ADD_MEM_REG),
143         REQ(VHOST_USER_REM_MEM_REG),
144         REQ(VHOST_USER_MAX),
145     };
146 #undef REQ
147
148     if (req < VHOST_USER_MAX) {
149         return vu_request_str[req];
150     } else {
151         return "unknown";
152     }
153 }
154
155 static void
156 vu_panic(VuDev *dev, const char *msg, ...)
157 {
158     char *buf = NULL;
159     va_list ap;
160
161     va_start(ap, msg);
162     if (vasprintf(&buf, msg, ap) < 0) {
163         buf = NULL;
164     }
165     va_end(ap);
166
167     dev->broken = true;
168     dev->panic(dev, buf);
169     free(buf);
170
171     /*
172      * FIXME:
173      * find a way to call virtio_error, or perhaps close the connection?
174      */
175 }
176
177 /* Translate guest physical address to our virtual address.  */
178 void *
179 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
180 {
181     int i;
182
183     if (*plen == 0) {
184         return NULL;
185     }
186
187     /* Find matching memory region.  */
188     for (i = 0; i < dev->nregions; i++) {
189         VuDevRegion *r = &dev->regions[i];
190
191         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
192             if ((guest_addr + *plen) > (r->gpa + r->size)) {
193                 *plen = r->gpa + r->size - guest_addr;
194             }
195             return (void *)(uintptr_t)
196                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
197         }
198     }
199
200     return NULL;
201 }
202
203 /* Translate qemu virtual address to our virtual address.  */
204 static void *
205 qva_to_va(VuDev *dev, uint64_t qemu_addr)
206 {
207     int i;
208
209     /* Find matching memory region.  */
210     for (i = 0; i < dev->nregions; i++) {
211         VuDevRegion *r = &dev->regions[i];
212
213         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
214             return (void *)(uintptr_t)
215                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
216         }
217     }
218
219     return NULL;
220 }
221
222 static void
223 vmsg_close_fds(VhostUserMsg *vmsg)
224 {
225     int i;
226
227     for (i = 0; i < vmsg->fd_num; i++) {
228         close(vmsg->fds[i]);
229     }
230 }
231
232 /* Set reply payload.u64 and clear request flags and fd_num */
233 static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
234 {
235     vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
236     vmsg->size = sizeof(vmsg->payload.u64);
237     vmsg->payload.u64 = val;
238     vmsg->fd_num = 0;
239 }
240
241 /* A test to see if we have userfault available */
242 static bool
243 have_userfault(void)
244 {
245 #if defined(__linux__) && defined(__NR_userfaultfd) &&\
246         defined(UFFD_FEATURE_MISSING_SHMEM) &&\
247         defined(UFFD_FEATURE_MISSING_HUGETLBFS)
248     /* Now test the kernel we're running on really has the features */
249     int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
250     struct uffdio_api api_struct;
251     if (ufd < 0) {
252         return false;
253     }
254
255     api_struct.api = UFFD_API;
256     api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
257                           UFFD_FEATURE_MISSING_HUGETLBFS;
258     if (ioctl(ufd, UFFDIO_API, &api_struct)) {
259         close(ufd);
260         return false;
261     }
262     close(ufd);
263     return true;
264
265 #else
266     return false;
267 #endif
268 }
269
270 static bool
271 vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
272 {
273     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
274     struct iovec iov = {
275         .iov_base = (char *)vmsg,
276         .iov_len = VHOST_USER_HDR_SIZE,
277     };
278     struct msghdr msg = {
279         .msg_iov = &iov,
280         .msg_iovlen = 1,
281         .msg_control = control,
282         .msg_controllen = sizeof(control),
283     };
284     size_t fd_size;
285     struct cmsghdr *cmsg;
286     int rc;
287
288     do {
289         rc = recvmsg(conn_fd, &msg, 0);
290     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
291
292     if (rc < 0) {
293         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
294         return false;
295     }
296
297     vmsg->fd_num = 0;
298     for (cmsg = CMSG_FIRSTHDR(&msg);
299          cmsg != NULL;
300          cmsg = CMSG_NXTHDR(&msg, cmsg))
301     {
302         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
303             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
304             vmsg->fd_num = fd_size / sizeof(int);
305             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
306             break;
307         }
308     }
309
310     if (vmsg->size > sizeof(vmsg->payload)) {
311         vu_panic(dev,
312                  "Error: too big message request: %d, size: vmsg->size: %u, "
313                  "while sizeof(vmsg->payload) = %zu\n",
314                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
315         goto fail;
316     }
317
318     if (vmsg->size) {
319         do {
320             rc = read(conn_fd, &vmsg->payload, vmsg->size);
321         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
322
323         if (rc <= 0) {
324             vu_panic(dev, "Error while reading: %s", strerror(errno));
325             goto fail;
326         }
327
328         assert(rc == vmsg->size);
329     }
330
331     return true;
332
333 fail:
334     vmsg_close_fds(vmsg);
335
336     return false;
337 }
338
339 static bool
340 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
341 {
342     int rc;
343     uint8_t *p = (uint8_t *)vmsg;
344     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
345     struct iovec iov = {
346         .iov_base = (char *)vmsg,
347         .iov_len = VHOST_USER_HDR_SIZE,
348     };
349     struct msghdr msg = {
350         .msg_iov = &iov,
351         .msg_iovlen = 1,
352         .msg_control = control,
353     };
354     struct cmsghdr *cmsg;
355
356     memset(control, 0, sizeof(control));
357     assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
358     if (vmsg->fd_num > 0) {
359         size_t fdsize = vmsg->fd_num * sizeof(int);
360         msg.msg_controllen = CMSG_SPACE(fdsize);
361         cmsg = CMSG_FIRSTHDR(&msg);
362         cmsg->cmsg_len = CMSG_LEN(fdsize);
363         cmsg->cmsg_level = SOL_SOCKET;
364         cmsg->cmsg_type = SCM_RIGHTS;
365         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
366     } else {
367         msg.msg_controllen = 0;
368     }
369
370     do {
371         rc = sendmsg(conn_fd, &msg, 0);
372     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
373
374     if (vmsg->size) {
375         do {
376             if (vmsg->data) {
377                 rc = write(conn_fd, vmsg->data, vmsg->size);
378             } else {
379                 rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
380             }
381         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
382     }
383
384     if (rc <= 0) {
385         vu_panic(dev, "Error while writing: %s", strerror(errno));
386         return false;
387     }
388
389     return true;
390 }
391
392 static bool
393 vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
394 {
395     /* Set the version in the flags when sending the reply */
396     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
397     vmsg->flags |= VHOST_USER_VERSION;
398     vmsg->flags |= VHOST_USER_REPLY_MASK;
399
400     return vu_message_write(dev, conn_fd, vmsg);
401 }
402
403 /*
404  * Processes a reply on the slave channel.
405  * Entered with slave_mutex held and releases it before exit.
406  * Returns true on success.
407  */
408 static bool
409 vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
410 {
411     VhostUserMsg msg_reply;
412     bool result = false;
413
414     if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
415         result = true;
416         goto out;
417     }
418
419     if (!vu_message_read(dev, dev->slave_fd, &msg_reply)) {
420         goto out;
421     }
422
423     if (msg_reply.request != vmsg->request) {
424         DPRINT("Received unexpected msg type. Expected %d received %d",
425                vmsg->request, msg_reply.request);
426         goto out;
427     }
428
429     result = msg_reply.payload.u64 == 0;
430
431 out:
432     pthread_mutex_unlock(&dev->slave_mutex);
433     return result;
434 }
435
436 /* Kick the log_call_fd if required. */
437 static void
438 vu_log_kick(VuDev *dev)
439 {
440     if (dev->log_call_fd != -1) {
441         DPRINT("Kicking the QEMU's log...\n");
442         if (eventfd_write(dev->log_call_fd, 1) < 0) {
443             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
444         }
445     }
446 }
447
448 static void
449 vu_log_page(uint8_t *log_table, uint64_t page)
450 {
451     DPRINT("Logged dirty guest page: %"PRId64"\n", page);
452     qatomic_or(&log_table[page / 8], 1 << (page % 8));
453 }
454
455 static void
456 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
457 {
458     uint64_t page;
459
460     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
461         !dev->log_table || !length) {
462         return;
463     }
464
465     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
466
467     page = address / VHOST_LOG_PAGE;
468     while (page * VHOST_LOG_PAGE < address + length) {
469         vu_log_page(dev->log_table, page);
470         page += 1;
471     }
472
473     vu_log_kick(dev);
474 }
475
476 static void
477 vu_kick_cb(VuDev *dev, int condition, void *data)
478 {
479     int index = (intptr_t)data;
480     VuVirtq *vq = &dev->vq[index];
481     int sock = vq->kick_fd;
482     eventfd_t kick_data;
483     ssize_t rc;
484
485     rc = eventfd_read(sock, &kick_data);
486     if (rc == -1) {
487         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
488         dev->remove_watch(dev, dev->vq[index].kick_fd);
489     } else {
490         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
491                kick_data, vq->handler, index);
492         if (vq->handler) {
493             vq->handler(dev, index);
494         }
495     }
496 }
497
498 static bool
499 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
500 {
501     vmsg->payload.u64 =
502         /*
503          * The following VIRTIO feature bits are supported by our virtqueue
504          * implementation:
505          */
506         1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
507         1ULL << VIRTIO_RING_F_INDIRECT_DESC |
508         1ULL << VIRTIO_RING_F_EVENT_IDX |
509         1ULL << VIRTIO_F_VERSION_1 |
510
511         /* vhost-user feature bits */
512         1ULL << VHOST_F_LOG_ALL |
513         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
514
515     if (dev->iface->get_features) {
516         vmsg->payload.u64 |= dev->iface->get_features(dev);
517     }
518
519     vmsg->size = sizeof(vmsg->payload.u64);
520     vmsg->fd_num = 0;
521
522     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
523
524     return true;
525 }
526
527 static void
528 vu_set_enable_all_rings(VuDev *dev, bool enabled)
529 {
530     uint16_t i;
531
532     for (i = 0; i < dev->max_queues; i++) {
533         dev->vq[i].enable = enabled;
534     }
535 }
536
537 static bool
538 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
539 {
540     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
541
542     dev->features = vmsg->payload.u64;
543     if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
544         /*
545          * We only support devices conforming to VIRTIO 1.0 or
546          * later
547          */
548         vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
549         return false;
550     }
551
552     if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
553         vu_set_enable_all_rings(dev, true);
554     }
555
556     if (dev->iface->set_features) {
557         dev->iface->set_features(dev, dev->features);
558     }
559
560     return false;
561 }
562
563 static bool
564 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
565 {
566     return false;
567 }
568
569 static void
570 vu_close_log(VuDev *dev)
571 {
572     if (dev->log_table) {
573         if (munmap(dev->log_table, dev->log_size) != 0) {
574             perror("close log munmap() error");
575         }
576
577         dev->log_table = NULL;
578     }
579     if (dev->log_call_fd != -1) {
580         close(dev->log_call_fd);
581         dev->log_call_fd = -1;
582     }
583 }
584
585 static bool
586 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
587 {
588     vu_set_enable_all_rings(dev, false);
589
590     return false;
591 }
592
593 static bool
594 map_ring(VuDev *dev, VuVirtq *vq)
595 {
596     vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
597     vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
598     vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
599
600     DPRINT("Setting virtq addresses:\n");
601     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
602     DPRINT("    vring_used  at %p\n", vq->vring.used);
603     DPRINT("    vring_avail at %p\n", vq->vring.avail);
604
605     return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
606 }
607
608 static bool
609 generate_faults(VuDev *dev) {
610     int i;
611     for (i = 0; i < dev->nregions; i++) {
612         VuDevRegion *dev_region = &dev->regions[i];
613         int ret;
614 #ifdef UFFDIO_REGISTER
615         /*
616          * We should already have an open ufd. Mark each memory
617          * range as ufd.
618          * Discard any mapping we have here; note I can't use MADV_REMOVE
619          * or fallocate to make the hole since I don't want to lose
620          * data that's already arrived in the shared process.
621          * TODO: How to do hugepage
622          */
623         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
624                       dev_region->size + dev_region->mmap_offset,
625                       MADV_DONTNEED);
626         if (ret) {
627             fprintf(stderr,
628                     "%s: Failed to madvise(DONTNEED) region %d: %s\n",
629                     __func__, i, strerror(errno));
630         }
631         /*
632          * Turn off transparent hugepages so we dont get lose wakeups
633          * in neighbouring pages.
634          * TODO: Turn this backon later.
635          */
636         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
637                       dev_region->size + dev_region->mmap_offset,
638                       MADV_NOHUGEPAGE);
639         if (ret) {
640             /*
641              * Note: This can happen legally on kernels that are configured
642              * without madvise'able hugepages
643              */
644             fprintf(stderr,
645                     "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
646                     __func__, i, strerror(errno));
647         }
648         struct uffdio_register reg_struct;
649         reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
650         reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
651         reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
652
653         if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
654             vu_panic(dev, "%s: Failed to userfault region %d "
655                           "@%p + size:%zx offset: %zx: (ufd=%d)%s\n",
656                      __func__, i,
657                      dev_region->mmap_addr,
658                      dev_region->size, dev_region->mmap_offset,
659                      dev->postcopy_ufd, strerror(errno));
660             return false;
661         }
662         if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
663             vu_panic(dev, "%s Region (%d) doesn't support COPY",
664                      __func__, i);
665             return false;
666         }
667         DPRINT("%s: region %d: Registered userfault for %"
668                PRIx64 " + %" PRIx64 "\n", __func__, i,
669                (uint64_t)reg_struct.range.start,
670                (uint64_t)reg_struct.range.len);
671         /* Now it's registered we can let the client at it */
672         if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
673                      dev_region->size + dev_region->mmap_offset,
674                      PROT_READ | PROT_WRITE)) {
675             vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
676                      i, strerror(errno));
677             return false;
678         }
679         /* TODO: Stash 'zero' support flags somewhere */
680 #endif
681     }
682
683     return true;
684 }
685
686 static bool
687 vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
688     int i;
689     bool track_ramblocks = dev->postcopy_listening;
690     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
691     VuDevRegion *dev_region = &dev->regions[dev->nregions];
692     void *mmap_addr;
693
694     /*
695      * If we are in postcopy mode and we receive a u64 payload with a 0 value
696      * we know all the postcopy client bases have been received, and we
697      * should start generating faults.
698      */
699     if (track_ramblocks &&
700         vmsg->size == sizeof(vmsg->payload.u64) &&
701         vmsg->payload.u64 == 0) {
702         (void)generate_faults(dev);
703         return false;
704     }
705
706     DPRINT("Adding region: %d\n", dev->nregions);
707     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
708            msg_region->guest_phys_addr);
709     DPRINT("    memory_size:     0x%016"PRIx64"\n",
710            msg_region->memory_size);
711     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
712            msg_region->userspace_addr);
713     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
714            msg_region->mmap_offset);
715
716     dev_region->gpa = msg_region->guest_phys_addr;
717     dev_region->size = msg_region->memory_size;
718     dev_region->qva = msg_region->userspace_addr;
719     dev_region->mmap_offset = msg_region->mmap_offset;
720
721     /*
722      * We don't use offset argument of mmap() since the
723      * mapped address has to be page aligned, and we use huge
724      * pages.
725      */
726     if (track_ramblocks) {
727         /*
728          * In postcopy we're using PROT_NONE here to catch anyone
729          * accessing it before we userfault.
730          */
731         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
732                          PROT_NONE, MAP_SHARED,
733                          vmsg->fds[0], 0);
734     } else {
735         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
736                          PROT_READ | PROT_WRITE, MAP_SHARED, vmsg->fds[0],
737                          0);
738     }
739
740     if (mmap_addr == MAP_FAILED) {
741         vu_panic(dev, "region mmap error: %s", strerror(errno));
742     } else {
743         dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
744         DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
745                dev_region->mmap_addr);
746     }
747
748     close(vmsg->fds[0]);
749
750     if (track_ramblocks) {
751         /*
752          * Return the address to QEMU so that it can translate the ufd
753          * fault addresses back.
754          */
755         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
756                                                  dev_region->mmap_offset);
757
758         /* Send the message back to qemu with the addresses filled in. */
759         vmsg->fd_num = 0;
760         if (!vu_send_reply(dev, dev->sock, vmsg)) {
761             vu_panic(dev, "failed to respond to add-mem-region for postcopy");
762             return false;
763         }
764
765         DPRINT("Successfully added new region in postcopy\n");
766         dev->nregions++;
767         return false;
768
769     } else {
770         for (i = 0; i < dev->max_queues; i++) {
771             if (dev->vq[i].vring.desc) {
772                 if (map_ring(dev, &dev->vq[i])) {
773                     vu_panic(dev, "remapping queue %d for new memory region",
774                              i);
775                 }
776             }
777         }
778
779         DPRINT("Successfully added new region\n");
780         dev->nregions++;
781         vmsg_set_reply_u64(vmsg, 0);
782         return true;
783     }
784 }
785
786 static inline bool reg_equal(VuDevRegion *vudev_reg,
787                              VhostUserMemoryRegion *msg_reg)
788 {
789     if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
790         vudev_reg->qva == msg_reg->userspace_addr &&
791         vudev_reg->size == msg_reg->memory_size) {
792         return true;
793     }
794
795     return false;
796 }
797
798 static bool
799 vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
800     int i, j;
801     bool found = false;
802     VuDevRegion shadow_regions[VHOST_USER_MAX_RAM_SLOTS] = {};
803     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
804
805     DPRINT("Removing region:\n");
806     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
807            msg_region->guest_phys_addr);
808     DPRINT("    memory_size:     0x%016"PRIx64"\n",
809            msg_region->memory_size);
810     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
811            msg_region->userspace_addr);
812     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
813            msg_region->mmap_offset);
814
815     for (i = 0, j = 0; i < dev->nregions; i++) {
816         if (!reg_equal(&dev->regions[i], msg_region)) {
817             shadow_regions[j].gpa = dev->regions[i].gpa;
818             shadow_regions[j].size = dev->regions[i].size;
819             shadow_regions[j].qva = dev->regions[i].qva;
820             shadow_regions[j].mmap_offset = dev->regions[i].mmap_offset;
821             j++;
822         } else {
823             found = true;
824             VuDevRegion *r = &dev->regions[i];
825             void *m = (void *) (uintptr_t) r->mmap_addr;
826
827             if (m) {
828                 munmap(m, r->size + r->mmap_offset);
829             }
830         }
831     }
832
833     if (found) {
834         memcpy(dev->regions, shadow_regions,
835                sizeof(VuDevRegion) * VHOST_USER_MAX_RAM_SLOTS);
836         DPRINT("Successfully removed a region\n");
837         dev->nregions--;
838         vmsg_set_reply_u64(vmsg, 0);
839     } else {
840         vu_panic(dev, "Specified region not found\n");
841     }
842
843     return true;
844 }
845
846 static bool
847 vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
848 {
849     int i;
850     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
851     dev->nregions = memory->nregions;
852
853     DPRINT("Nregions: %d\n", memory->nregions);
854     for (i = 0; i < dev->nregions; i++) {
855         void *mmap_addr;
856         VhostUserMemoryRegion *msg_region = &memory->regions[i];
857         VuDevRegion *dev_region = &dev->regions[i];
858
859         DPRINT("Region %d\n", i);
860         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
861                msg_region->guest_phys_addr);
862         DPRINT("    memory_size:     0x%016"PRIx64"\n",
863                msg_region->memory_size);
864         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
865                msg_region->userspace_addr);
866         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
867                msg_region->mmap_offset);
868
869         dev_region->gpa = msg_region->guest_phys_addr;
870         dev_region->size = msg_region->memory_size;
871         dev_region->qva = msg_region->userspace_addr;
872         dev_region->mmap_offset = msg_region->mmap_offset;
873
874         /* We don't use offset argument of mmap() since the
875          * mapped address has to be page aligned, and we use huge
876          * pages.
877          * In postcopy we're using PROT_NONE here to catch anyone
878          * accessing it before we userfault
879          */
880         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
881                          PROT_NONE, MAP_SHARED,
882                          vmsg->fds[i], 0);
883
884         if (mmap_addr == MAP_FAILED) {
885             vu_panic(dev, "region mmap error: %s", strerror(errno));
886         } else {
887             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
888             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
889                    dev_region->mmap_addr);
890         }
891
892         /* Return the address to QEMU so that it can translate the ufd
893          * fault addresses back.
894          */
895         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
896                                                  dev_region->mmap_offset);
897         close(vmsg->fds[i]);
898     }
899
900     /* Send the message back to qemu with the addresses filled in */
901     vmsg->fd_num = 0;
902     if (!vu_send_reply(dev, dev->sock, vmsg)) {
903         vu_panic(dev, "failed to respond to set-mem-table for postcopy");
904         return false;
905     }
906
907     /* Wait for QEMU to confirm that it's registered the handler for the
908      * faults.
909      */
910     if (!vu_message_read(dev, dev->sock, vmsg) ||
911         vmsg->size != sizeof(vmsg->payload.u64) ||
912         vmsg->payload.u64 != 0) {
913         vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
914         return false;
915     }
916
917     /* OK, now we can go and register the memory and generate faults */
918     (void)generate_faults(dev);
919
920     return false;
921 }
922
923 static bool
924 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
925 {
926     int i;
927     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
928
929     for (i = 0; i < dev->nregions; i++) {
930         VuDevRegion *r = &dev->regions[i];
931         void *m = (void *) (uintptr_t) r->mmap_addr;
932
933         if (m) {
934             munmap(m, r->size + r->mmap_offset);
935         }
936     }
937     dev->nregions = memory->nregions;
938
939     if (dev->postcopy_listening) {
940         return vu_set_mem_table_exec_postcopy(dev, vmsg);
941     }
942
943     DPRINT("Nregions: %d\n", memory->nregions);
944     for (i = 0; i < dev->nregions; i++) {
945         void *mmap_addr;
946         VhostUserMemoryRegion *msg_region = &memory->regions[i];
947         VuDevRegion *dev_region = &dev->regions[i];
948
949         DPRINT("Region %d\n", i);
950         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
951                msg_region->guest_phys_addr);
952         DPRINT("    memory_size:     0x%016"PRIx64"\n",
953                msg_region->memory_size);
954         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
955                msg_region->userspace_addr);
956         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
957                msg_region->mmap_offset);
958
959         dev_region->gpa = msg_region->guest_phys_addr;
960         dev_region->size = msg_region->memory_size;
961         dev_region->qva = msg_region->userspace_addr;
962         dev_region->mmap_offset = msg_region->mmap_offset;
963
964         /* We don't use offset argument of mmap() since the
965          * mapped address has to be page aligned, and we use huge
966          * pages.  */
967         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
968                          PROT_READ | PROT_WRITE, MAP_SHARED,
969                          vmsg->fds[i], 0);
970
971         if (mmap_addr == MAP_FAILED) {
972             vu_panic(dev, "region mmap error: %s", strerror(errno));
973         } else {
974             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
975             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
976                    dev_region->mmap_addr);
977         }
978
979         close(vmsg->fds[i]);
980     }
981
982     for (i = 0; i < dev->max_queues; i++) {
983         if (dev->vq[i].vring.desc) {
984             if (map_ring(dev, &dev->vq[i])) {
985                 vu_panic(dev, "remapping queue %d during setmemtable", i);
986             }
987         }
988     }
989
990     return false;
991 }
992
993 static bool
994 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
995 {
996     int fd;
997     uint64_t log_mmap_size, log_mmap_offset;
998     void *rc;
999
1000     if (vmsg->fd_num != 1 ||
1001         vmsg->size != sizeof(vmsg->payload.log)) {
1002         vu_panic(dev, "Invalid log_base message");
1003         return true;
1004     }
1005
1006     fd = vmsg->fds[0];
1007     log_mmap_offset = vmsg->payload.log.mmap_offset;
1008     log_mmap_size = vmsg->payload.log.mmap_size;
1009     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1010     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1011
1012     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1013               log_mmap_offset);
1014     close(fd);
1015     if (rc == MAP_FAILED) {
1016         perror("log mmap error");
1017     }
1018
1019     if (dev->log_table) {
1020         munmap(dev->log_table, dev->log_size);
1021     }
1022     dev->log_table = rc;
1023     dev->log_size = log_mmap_size;
1024
1025     vmsg->size = sizeof(vmsg->payload.u64);
1026     vmsg->fd_num = 0;
1027
1028     return true;
1029 }
1030
1031 static bool
1032 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1033 {
1034     if (vmsg->fd_num != 1) {
1035         vu_panic(dev, "Invalid log_fd message");
1036         return false;
1037     }
1038
1039     if (dev->log_call_fd != -1) {
1040         close(dev->log_call_fd);
1041     }
1042     dev->log_call_fd = vmsg->fds[0];
1043     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1044
1045     return false;
1046 }
1047
1048 static bool
1049 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1050 {
1051     unsigned int index = vmsg->payload.state.index;
1052     unsigned int num = vmsg->payload.state.num;
1053
1054     DPRINT("State.index: %d\n", index);
1055     DPRINT("State.num:   %d\n", num);
1056     dev->vq[index].vring.num = num;
1057
1058     return false;
1059 }
1060
1061 static bool
1062 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1063 {
1064     struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1065     unsigned int index = vra->index;
1066     VuVirtq *vq = &dev->vq[index];
1067
1068     DPRINT("vhost_vring_addr:\n");
1069     DPRINT("    index:  %d\n", vra->index);
1070     DPRINT("    flags:  %d\n", vra->flags);
1071     DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", vra->desc_user_addr);
1072     DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", vra->used_user_addr);
1073     DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", vra->avail_user_addr);
1074     DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", vra->log_guest_addr);
1075
1076     vq->vra = *vra;
1077     vq->vring.flags = vra->flags;
1078     vq->vring.log_guest_addr = vra->log_guest_addr;
1079
1080
1081     if (map_ring(dev, vq)) {
1082         vu_panic(dev, "Invalid vring_addr message");
1083         return false;
1084     }
1085
1086     vq->used_idx = lduw_le_p(&vq->vring.used->idx);
1087
1088     if (vq->last_avail_idx != vq->used_idx) {
1089         bool resume = dev->iface->queue_is_processed_in_order &&
1090             dev->iface->queue_is_processed_in_order(dev, index);
1091
1092         DPRINT("Last avail index != used index: %u != %u%s\n",
1093                vq->last_avail_idx, vq->used_idx,
1094                resume ? ", resuming" : "");
1095
1096         if (resume) {
1097             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1098         }
1099     }
1100
1101     return false;
1102 }
1103
1104 static bool
1105 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1106 {
1107     unsigned int index = vmsg->payload.state.index;
1108     unsigned int num = vmsg->payload.state.num;
1109
1110     DPRINT("State.index: %d\n", index);
1111     DPRINT("State.num:   %d\n", num);
1112     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1113
1114     return false;
1115 }
1116
1117 static bool
1118 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1119 {
1120     unsigned int index = vmsg->payload.state.index;
1121
1122     DPRINT("State.index: %d\n", index);
1123     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1124     vmsg->size = sizeof(vmsg->payload.state);
1125
1126     dev->vq[index].started = false;
1127     if (dev->iface->queue_set_started) {
1128         dev->iface->queue_set_started(dev, index, false);
1129     }
1130
1131     if (dev->vq[index].call_fd != -1) {
1132         close(dev->vq[index].call_fd);
1133         dev->vq[index].call_fd = -1;
1134     }
1135     if (dev->vq[index].kick_fd != -1) {
1136         dev->remove_watch(dev, dev->vq[index].kick_fd);
1137         close(dev->vq[index].kick_fd);
1138         dev->vq[index].kick_fd = -1;
1139     }
1140
1141     return true;
1142 }
1143
1144 static bool
1145 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1146 {
1147     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1148     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1149
1150     if (index >= dev->max_queues) {
1151         vmsg_close_fds(vmsg);
1152         vu_panic(dev, "Invalid queue index: %u", index);
1153         return false;
1154     }
1155
1156     if (nofd) {
1157         vmsg_close_fds(vmsg);
1158         return true;
1159     }
1160
1161     if (vmsg->fd_num != 1) {
1162         vmsg_close_fds(vmsg);
1163         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1164         return false;
1165     }
1166
1167     return true;
1168 }
1169
1170 static int
1171 inflight_desc_compare(const void *a, const void *b)
1172 {
1173     VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1174                         *desc1 = (VuVirtqInflightDesc *)b;
1175
1176     if (desc1->counter > desc0->counter &&
1177         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1178         return 1;
1179     }
1180
1181     return -1;
1182 }
1183
1184 static int
1185 vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1186 {
1187     int i = 0;
1188
1189     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1190         return 0;
1191     }
1192
1193     if (unlikely(!vq->inflight)) {
1194         return -1;
1195     }
1196
1197     if (unlikely(!vq->inflight->version)) {
1198         /* initialize the buffer */
1199         vq->inflight->version = INFLIGHT_VERSION;
1200         return 0;
1201     }
1202
1203     vq->used_idx = lduw_le_p(&vq->vring.used->idx);
1204     vq->resubmit_num = 0;
1205     vq->resubmit_list = NULL;
1206     vq->counter = 0;
1207
1208     if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1209         vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1210
1211         barrier();
1212
1213         vq->inflight->used_idx = vq->used_idx;
1214     }
1215
1216     for (i = 0; i < vq->inflight->desc_num; i++) {
1217         if (vq->inflight->desc[i].inflight == 1) {
1218             vq->inuse++;
1219         }
1220     }
1221
1222     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1223
1224     if (vq->inuse) {
1225         vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1226         if (!vq->resubmit_list) {
1227             return -1;
1228         }
1229
1230         for (i = 0; i < vq->inflight->desc_num; i++) {
1231             if (vq->inflight->desc[i].inflight) {
1232                 vq->resubmit_list[vq->resubmit_num].index = i;
1233                 vq->resubmit_list[vq->resubmit_num].counter =
1234                                         vq->inflight->desc[i].counter;
1235                 vq->resubmit_num++;
1236             }
1237         }
1238
1239         if (vq->resubmit_num > 1) {
1240             qsort(vq->resubmit_list, vq->resubmit_num,
1241                   sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1242         }
1243         vq->counter = vq->resubmit_list[0].counter + 1;
1244     }
1245
1246     /* in case of I/O hang after reconnecting */
1247     if (eventfd_write(vq->kick_fd, 1)) {
1248         return -1;
1249     }
1250
1251     return 0;
1252 }
1253
1254 static bool
1255 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1256 {
1257     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1258     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1259
1260     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1261
1262     if (!vu_check_queue_msg_file(dev, vmsg)) {
1263         return false;
1264     }
1265
1266     if (dev->vq[index].kick_fd != -1) {
1267         dev->remove_watch(dev, dev->vq[index].kick_fd);
1268         close(dev->vq[index].kick_fd);
1269         dev->vq[index].kick_fd = -1;
1270     }
1271
1272     dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1273     DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1274
1275     dev->vq[index].started = true;
1276     if (dev->iface->queue_set_started) {
1277         dev->iface->queue_set_started(dev, index, true);
1278     }
1279
1280     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1281         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1282                        vu_kick_cb, (void *)(long)index);
1283
1284         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1285                dev->vq[index].kick_fd, index);
1286     }
1287
1288     if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1289         vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1290     }
1291
1292     return false;
1293 }
1294
1295 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1296                           vu_queue_handler_cb handler)
1297 {
1298     int qidx = vq - dev->vq;
1299
1300     vq->handler = handler;
1301     if (vq->kick_fd >= 0) {
1302         if (handler) {
1303             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1304                            vu_kick_cb, (void *)(long)qidx);
1305         } else {
1306             dev->remove_watch(dev, vq->kick_fd);
1307         }
1308     }
1309 }
1310
1311 bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1312                                 int size, int offset)
1313 {
1314     int qidx = vq - dev->vq;
1315     int fd_num = 0;
1316     VhostUserMsg vmsg = {
1317         .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1318         .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1319         .size = sizeof(vmsg.payload.area),
1320         .payload.area = {
1321             .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1322             .size = size,
1323             .offset = offset,
1324         },
1325     };
1326
1327     if (fd == -1) {
1328         vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1329     } else {
1330         vmsg.fds[fd_num++] = fd;
1331     }
1332
1333     vmsg.fd_num = fd_num;
1334
1335     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1336         return false;
1337     }
1338
1339     pthread_mutex_lock(&dev->slave_mutex);
1340     if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1341         pthread_mutex_unlock(&dev->slave_mutex);
1342         return false;
1343     }
1344
1345     /* Also unlocks the slave_mutex */
1346     return vu_process_message_reply(dev, &vmsg);
1347 }
1348
1349 static bool
1350 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1351 {
1352     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1353     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1354
1355     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1356
1357     if (!vu_check_queue_msg_file(dev, vmsg)) {
1358         return false;
1359     }
1360
1361     if (dev->vq[index].call_fd != -1) {
1362         close(dev->vq[index].call_fd);
1363         dev->vq[index].call_fd = -1;
1364     }
1365
1366     dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1367
1368     /* in case of I/O hang after reconnecting */
1369     if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
1370         return -1;
1371     }
1372
1373     DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1374
1375     return false;
1376 }
1377
1378 static bool
1379 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1380 {
1381     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1382     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1383
1384     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1385
1386     if (!vu_check_queue_msg_file(dev, vmsg)) {
1387         return false;
1388     }
1389
1390     if (dev->vq[index].err_fd != -1) {
1391         close(dev->vq[index].err_fd);
1392         dev->vq[index].err_fd = -1;
1393     }
1394
1395     dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1396
1397     return false;
1398 }
1399
1400 static bool
1401 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1402 {
1403     /*
1404      * Note that we support, but intentionally do not set,
1405      * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1406      * a device implementation can return it in its callback
1407      * (get_protocol_features) if it wants to use this for
1408      * simulation, but it is otherwise not desirable (if even
1409      * implemented by the master.)
1410      */
1411     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1412                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1413                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1414                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1415                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1416                         1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1417                         1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1418
1419     if (have_userfault()) {
1420         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1421     }
1422
1423     if (dev->iface->get_config && dev->iface->set_config) {
1424         features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1425     }
1426
1427     if (dev->iface->get_protocol_features) {
1428         features |= dev->iface->get_protocol_features(dev);
1429     }
1430
1431     vmsg_set_reply_u64(vmsg, features);
1432     return true;
1433 }
1434
1435 static bool
1436 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1437 {
1438     uint64_t features = vmsg->payload.u64;
1439
1440     DPRINT("u64: 0x%016"PRIx64"\n", features);
1441
1442     dev->protocol_features = vmsg->payload.u64;
1443
1444     if (vu_has_protocol_feature(dev,
1445                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1446         (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1447          !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1448         /*
1449          * The use case for using messages for kick/call is simulation, to make
1450          * the kick and call synchronous. To actually get that behaviour, both
1451          * of the other features are required.
1452          * Theoretically, one could use only kick messages, or do them without
1453          * having F_REPLY_ACK, but too many (possibly pending) messages on the
1454          * socket will eventually cause the master to hang, to avoid this in
1455          * scenarios where not desired enforce that the settings are in a way
1456          * that actually enables the simulation case.
1457          */
1458         vu_panic(dev,
1459                  "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1460         return false;
1461     }
1462
1463     if (dev->iface->set_protocol_features) {
1464         dev->iface->set_protocol_features(dev, features);
1465     }
1466
1467     return false;
1468 }
1469
1470 static bool
1471 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1472 {
1473     vmsg_set_reply_u64(vmsg, dev->max_queues);
1474     return true;
1475 }
1476
1477 static bool
1478 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1479 {
1480     unsigned int index = vmsg->payload.state.index;
1481     unsigned int enable = vmsg->payload.state.num;
1482
1483     DPRINT("State.index: %d\n", index);
1484     DPRINT("State.enable:   %d\n", enable);
1485
1486     if (index >= dev->max_queues) {
1487         vu_panic(dev, "Invalid vring_enable index: %u", index);
1488         return false;
1489     }
1490
1491     dev->vq[index].enable = enable;
1492     return false;
1493 }
1494
1495 static bool
1496 vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1497 {
1498     if (vmsg->fd_num != 1) {
1499         vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1500         return false;
1501     }
1502
1503     if (dev->slave_fd != -1) {
1504         close(dev->slave_fd);
1505     }
1506     dev->slave_fd = vmsg->fds[0];
1507     DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1508
1509     return false;
1510 }
1511
1512 static bool
1513 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1514 {
1515     int ret = -1;
1516
1517     if (dev->iface->get_config) {
1518         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1519                                      vmsg->payload.config.size);
1520     }
1521
1522     if (ret) {
1523         /* resize to zero to indicate an error to master */
1524         vmsg->size = 0;
1525     }
1526
1527     return true;
1528 }
1529
1530 static bool
1531 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1532 {
1533     int ret = -1;
1534
1535     if (dev->iface->set_config) {
1536         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1537                                      vmsg->payload.config.offset,
1538                                      vmsg->payload.config.size,
1539                                      vmsg->payload.config.flags);
1540         if (ret) {
1541             vu_panic(dev, "Set virtio configuration space failed");
1542         }
1543     }
1544
1545     return false;
1546 }
1547
1548 static bool
1549 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1550 {
1551     dev->postcopy_ufd = -1;
1552 #ifdef UFFDIO_API
1553     struct uffdio_api api_struct;
1554
1555     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1556     vmsg->size = 0;
1557 #endif
1558
1559     if (dev->postcopy_ufd == -1) {
1560         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1561         goto out;
1562     }
1563
1564 #ifdef UFFDIO_API
1565     api_struct.api = UFFD_API;
1566     api_struct.features = 0;
1567     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1568         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1569         close(dev->postcopy_ufd);
1570         dev->postcopy_ufd = -1;
1571         goto out;
1572     }
1573     /* TODO: Stash feature flags somewhere */
1574 #endif
1575
1576 out:
1577     /* Return a ufd to the QEMU */
1578     vmsg->fd_num = 1;
1579     vmsg->fds[0] = dev->postcopy_ufd;
1580     return true; /* = send a reply */
1581 }
1582
1583 static bool
1584 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1585 {
1586     if (dev->nregions) {
1587         vu_panic(dev, "Regions already registered at postcopy-listen");
1588         vmsg_set_reply_u64(vmsg, -1);
1589         return true;
1590     }
1591     dev->postcopy_listening = true;
1592
1593     vmsg_set_reply_u64(vmsg, 0);
1594     return true;
1595 }
1596
1597 static bool
1598 vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1599 {
1600     DPRINT("%s: Entry\n", __func__);
1601     dev->postcopy_listening = false;
1602     if (dev->postcopy_ufd > 0) {
1603         close(dev->postcopy_ufd);
1604         dev->postcopy_ufd = -1;
1605         DPRINT("%s: Done close\n", __func__);
1606     }
1607
1608     vmsg_set_reply_u64(vmsg, 0);
1609     DPRINT("%s: exit\n", __func__);
1610     return true;
1611 }
1612
1613 static inline uint64_t
1614 vu_inflight_queue_size(uint16_t queue_size)
1615 {
1616     return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1617            sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1618 }
1619
1620 static bool
1621 vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1622 {
1623     int fd;
1624     void *addr;
1625     uint64_t mmap_size;
1626     uint16_t num_queues, queue_size;
1627
1628     if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1629         vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1630         vmsg->payload.inflight.mmap_size = 0;
1631         return true;
1632     }
1633
1634     num_queues = vmsg->payload.inflight.num_queues;
1635     queue_size = vmsg->payload.inflight.queue_size;
1636
1637     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1638     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1639
1640     mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1641
1642     addr = qemu_memfd_alloc("vhost-inflight", mmap_size,
1643                             F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1644                             &fd, NULL);
1645
1646     if (!addr) {
1647         vu_panic(dev, "Failed to alloc vhost inflight area");
1648         vmsg->payload.inflight.mmap_size = 0;
1649         return true;
1650     }
1651
1652     memset(addr, 0, mmap_size);
1653
1654     dev->inflight_info.addr = addr;
1655     dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1656     dev->inflight_info.fd = vmsg->fds[0] = fd;
1657     vmsg->fd_num = 1;
1658     vmsg->payload.inflight.mmap_offset = 0;
1659
1660     DPRINT("send inflight mmap_size: %"PRId64"\n",
1661            vmsg->payload.inflight.mmap_size);
1662     DPRINT("send inflight mmap offset: %"PRId64"\n",
1663            vmsg->payload.inflight.mmap_offset);
1664
1665     return true;
1666 }
1667
1668 static bool
1669 vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1670 {
1671     int fd, i;
1672     uint64_t mmap_size, mmap_offset;
1673     uint16_t num_queues, queue_size;
1674     void *rc;
1675
1676     if (vmsg->fd_num != 1 ||
1677         vmsg->size != sizeof(vmsg->payload.inflight)) {
1678         vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1679                  vmsg->size, vmsg->fd_num);
1680         return false;
1681     }
1682
1683     fd = vmsg->fds[0];
1684     mmap_size = vmsg->payload.inflight.mmap_size;
1685     mmap_offset = vmsg->payload.inflight.mmap_offset;
1686     num_queues = vmsg->payload.inflight.num_queues;
1687     queue_size = vmsg->payload.inflight.queue_size;
1688
1689     DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1690     DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1691     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1692     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1693
1694     rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1695               fd, mmap_offset);
1696
1697     if (rc == MAP_FAILED) {
1698         vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1699         return false;
1700     }
1701
1702     if (dev->inflight_info.fd) {
1703         close(dev->inflight_info.fd);
1704     }
1705
1706     if (dev->inflight_info.addr) {
1707         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1708     }
1709
1710     dev->inflight_info.fd = fd;
1711     dev->inflight_info.addr = rc;
1712     dev->inflight_info.size = mmap_size;
1713
1714     for (i = 0; i < num_queues; i++) {
1715         dev->vq[i].inflight = (VuVirtqInflight *)rc;
1716         dev->vq[i].inflight->desc_num = queue_size;
1717         rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1718     }
1719
1720     return false;
1721 }
1722
1723 static bool
1724 vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1725 {
1726     unsigned int index = vmsg->payload.state.index;
1727
1728     if (index >= dev->max_queues) {
1729         vu_panic(dev, "Invalid queue index: %u", index);
1730         return false;
1731     }
1732
1733     DPRINT("Got kick message: handler:%p idx:%d\n",
1734            dev->vq[index].handler, index);
1735
1736     if (!dev->vq[index].started) {
1737         dev->vq[index].started = true;
1738
1739         if (dev->iface->queue_set_started) {
1740             dev->iface->queue_set_started(dev, index, true);
1741         }
1742     }
1743
1744     if (dev->vq[index].handler) {
1745         dev->vq[index].handler(dev, index);
1746     }
1747
1748     return false;
1749 }
1750
1751 static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1752 {
1753     vmsg->flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION;
1754     vmsg->size  = sizeof(vmsg->payload.u64);
1755     vmsg->payload.u64 = VHOST_USER_MAX_RAM_SLOTS;
1756     vmsg->fd_num = 0;
1757
1758     if (!vu_message_write(dev, dev->sock, vmsg)) {
1759         vu_panic(dev, "Failed to send max ram slots: %s\n", strerror(errno));
1760     }
1761
1762     DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1763
1764     return false;
1765 }
1766
1767 static bool
1768 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1769 {
1770     int do_reply = 0;
1771
1772     /* Print out generic part of the request. */
1773     DPRINT("================ Vhost user message ================\n");
1774     DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1775            vmsg->request);
1776     DPRINT("Flags:   0x%x\n", vmsg->flags);
1777     DPRINT("Size:    %d\n", vmsg->size);
1778
1779     if (vmsg->fd_num) {
1780         int i;
1781         DPRINT("Fds:");
1782         for (i = 0; i < vmsg->fd_num; i++) {
1783             DPRINT(" %d", vmsg->fds[i]);
1784         }
1785         DPRINT("\n");
1786     }
1787
1788     if (dev->iface->process_msg &&
1789         dev->iface->process_msg(dev, vmsg, &do_reply)) {
1790         return do_reply;
1791     }
1792
1793     switch (vmsg->request) {
1794     case VHOST_USER_GET_FEATURES:
1795         return vu_get_features_exec(dev, vmsg);
1796     case VHOST_USER_SET_FEATURES:
1797         return vu_set_features_exec(dev, vmsg);
1798     case VHOST_USER_GET_PROTOCOL_FEATURES:
1799         return vu_get_protocol_features_exec(dev, vmsg);
1800     case VHOST_USER_SET_PROTOCOL_FEATURES:
1801         return vu_set_protocol_features_exec(dev, vmsg);
1802     case VHOST_USER_SET_OWNER:
1803         return vu_set_owner_exec(dev, vmsg);
1804     case VHOST_USER_RESET_OWNER:
1805         return vu_reset_device_exec(dev, vmsg);
1806     case VHOST_USER_SET_MEM_TABLE:
1807         return vu_set_mem_table_exec(dev, vmsg);
1808     case VHOST_USER_SET_LOG_BASE:
1809         return vu_set_log_base_exec(dev, vmsg);
1810     case VHOST_USER_SET_LOG_FD:
1811         return vu_set_log_fd_exec(dev, vmsg);
1812     case VHOST_USER_SET_VRING_NUM:
1813         return vu_set_vring_num_exec(dev, vmsg);
1814     case VHOST_USER_SET_VRING_ADDR:
1815         return vu_set_vring_addr_exec(dev, vmsg);
1816     case VHOST_USER_SET_VRING_BASE:
1817         return vu_set_vring_base_exec(dev, vmsg);
1818     case VHOST_USER_GET_VRING_BASE:
1819         return vu_get_vring_base_exec(dev, vmsg);
1820     case VHOST_USER_SET_VRING_KICK:
1821         return vu_set_vring_kick_exec(dev, vmsg);
1822     case VHOST_USER_SET_VRING_CALL:
1823         return vu_set_vring_call_exec(dev, vmsg);
1824     case VHOST_USER_SET_VRING_ERR:
1825         return vu_set_vring_err_exec(dev, vmsg);
1826     case VHOST_USER_GET_QUEUE_NUM:
1827         return vu_get_queue_num_exec(dev, vmsg);
1828     case VHOST_USER_SET_VRING_ENABLE:
1829         return vu_set_vring_enable_exec(dev, vmsg);
1830     case VHOST_USER_SET_SLAVE_REQ_FD:
1831         return vu_set_slave_req_fd(dev, vmsg);
1832     case VHOST_USER_GET_CONFIG:
1833         return vu_get_config(dev, vmsg);
1834     case VHOST_USER_SET_CONFIG:
1835         return vu_set_config(dev, vmsg);
1836     case VHOST_USER_NONE:
1837         /* if you need processing before exit, override iface->process_msg */
1838         exit(0);
1839     case VHOST_USER_POSTCOPY_ADVISE:
1840         return vu_set_postcopy_advise(dev, vmsg);
1841     case VHOST_USER_POSTCOPY_LISTEN:
1842         return vu_set_postcopy_listen(dev, vmsg);
1843     case VHOST_USER_POSTCOPY_END:
1844         return vu_set_postcopy_end(dev, vmsg);
1845     case VHOST_USER_GET_INFLIGHT_FD:
1846         return vu_get_inflight_fd(dev, vmsg);
1847     case VHOST_USER_SET_INFLIGHT_FD:
1848         return vu_set_inflight_fd(dev, vmsg);
1849     case VHOST_USER_VRING_KICK:
1850         return vu_handle_vring_kick(dev, vmsg);
1851     case VHOST_USER_GET_MAX_MEM_SLOTS:
1852         return vu_handle_get_max_memslots(dev, vmsg);
1853     case VHOST_USER_ADD_MEM_REG:
1854         return vu_add_mem_reg(dev, vmsg);
1855     case VHOST_USER_REM_MEM_REG:
1856         return vu_rem_mem_reg(dev, vmsg);
1857     default:
1858         vmsg_close_fds(vmsg);
1859         vu_panic(dev, "Unhandled request: %d", vmsg->request);
1860     }
1861
1862     return false;
1863 }
1864
1865 bool
1866 vu_dispatch(VuDev *dev)
1867 {
1868     VhostUserMsg vmsg = { 0, };
1869     int reply_requested;
1870     bool need_reply, success = false;
1871
1872     if (!vu_message_read(dev, dev->sock, &vmsg)) {
1873         goto end;
1874     }
1875
1876     need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1877
1878     reply_requested = vu_process_message(dev, &vmsg);
1879     if (!reply_requested && need_reply) {
1880         vmsg_set_reply_u64(&vmsg, 0);
1881         reply_requested = 1;
1882     }
1883
1884     if (!reply_requested) {
1885         success = true;
1886         goto end;
1887     }
1888
1889     if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1890         goto end;
1891     }
1892
1893     success = true;
1894
1895 end:
1896     free(vmsg.data);
1897     return success;
1898 }
1899
1900 void
1901 vu_deinit(VuDev *dev)
1902 {
1903     int i;
1904
1905     for (i = 0; i < dev->nregions; i++) {
1906         VuDevRegion *r = &dev->regions[i];
1907         void *m = (void *) (uintptr_t) r->mmap_addr;
1908         if (m != MAP_FAILED) {
1909             munmap(m, r->size + r->mmap_offset);
1910         }
1911     }
1912     dev->nregions = 0;
1913
1914     for (i = 0; i < dev->max_queues; i++) {
1915         VuVirtq *vq = &dev->vq[i];
1916
1917         if (vq->call_fd != -1) {
1918             close(vq->call_fd);
1919             vq->call_fd = -1;
1920         }
1921
1922         if (vq->kick_fd != -1) {
1923             close(vq->kick_fd);
1924             vq->kick_fd = -1;
1925         }
1926
1927         if (vq->err_fd != -1) {
1928             close(vq->err_fd);
1929             vq->err_fd = -1;
1930         }
1931
1932         if (vq->resubmit_list) {
1933             free(vq->resubmit_list);
1934             vq->resubmit_list = NULL;
1935         }
1936
1937         vq->inflight = NULL;
1938     }
1939
1940     if (dev->inflight_info.addr) {
1941         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1942         dev->inflight_info.addr = NULL;
1943     }
1944
1945     if (dev->inflight_info.fd > 0) {
1946         close(dev->inflight_info.fd);
1947         dev->inflight_info.fd = -1;
1948     }
1949
1950     vu_close_log(dev);
1951     if (dev->slave_fd != -1) {
1952         close(dev->slave_fd);
1953         dev->slave_fd = -1;
1954     }
1955     pthread_mutex_destroy(&dev->slave_mutex);
1956
1957     if (dev->sock != -1) {
1958         close(dev->sock);
1959     }
1960
1961     free(dev->vq);
1962     dev->vq = NULL;
1963 }
1964
1965 bool
1966 vu_init(VuDev *dev,
1967         uint16_t max_queues,
1968         int socket,
1969         vu_panic_cb panic,
1970         vu_set_watch_cb set_watch,
1971         vu_remove_watch_cb remove_watch,
1972         const VuDevIface *iface)
1973 {
1974     uint16_t i;
1975
1976     assert(max_queues > 0);
1977     assert(socket >= 0);
1978     assert(set_watch);
1979     assert(remove_watch);
1980     assert(iface);
1981     assert(panic);
1982
1983     memset(dev, 0, sizeof(*dev));
1984
1985     dev->sock = socket;
1986     dev->panic = panic;
1987     dev->set_watch = set_watch;
1988     dev->remove_watch = remove_watch;
1989     dev->iface = iface;
1990     dev->log_call_fd = -1;
1991     pthread_mutex_init(&dev->slave_mutex, NULL);
1992     dev->slave_fd = -1;
1993     dev->max_queues = max_queues;
1994
1995     dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
1996     if (!dev->vq) {
1997         DPRINT("%s: failed to malloc virtqueues\n", __func__);
1998         return false;
1999     }
2000
2001     for (i = 0; i < max_queues; i++) {
2002         dev->vq[i] = (VuVirtq) {
2003             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2004             .notification = true,
2005         };
2006     }
2007
2008     return true;
2009 }
2010
2011 VuVirtq *
2012 vu_get_queue(VuDev *dev, int qidx)
2013 {
2014     assert(qidx < dev->max_queues);
2015     return &dev->vq[qidx];
2016 }
2017
2018 bool
2019 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2020 {
2021     return vq->enable;
2022 }
2023
2024 bool
2025 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2026 {
2027     return vq->started;
2028 }
2029
2030 static inline uint16_t
2031 vring_avail_flags(VuVirtq *vq)
2032 {
2033     return lduw_le_p(&vq->vring.avail->flags);
2034 }
2035
2036 static inline uint16_t
2037 vring_avail_idx(VuVirtq *vq)
2038 {
2039     vq->shadow_avail_idx = lduw_le_p(&vq->vring.avail->idx);
2040
2041     return vq->shadow_avail_idx;
2042 }
2043
2044 static inline uint16_t
2045 vring_avail_ring(VuVirtq *vq, int i)
2046 {
2047     return lduw_le_p(&vq->vring.avail->ring[i]);
2048 }
2049
2050 static inline uint16_t
2051 vring_get_used_event(VuVirtq *vq)
2052 {
2053     return vring_avail_ring(vq, vq->vring.num);
2054 }
2055
2056 static int
2057 virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2058 {
2059     uint16_t num_heads = vring_avail_idx(vq) - idx;
2060
2061     /* Check it isn't doing very strange things with descriptor numbers. */
2062     if (num_heads > vq->vring.num) {
2063         vu_panic(dev, "Guest moved used index from %u to %u",
2064                  idx, vq->shadow_avail_idx);
2065         return -1;
2066     }
2067     if (num_heads) {
2068         /* On success, callers read a descriptor at vq->last_avail_idx.
2069          * Make sure descriptor read does not bypass avail index read. */
2070         smp_rmb();
2071     }
2072
2073     return num_heads;
2074 }
2075
2076 static bool
2077 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2078                    unsigned int idx, unsigned int *head)
2079 {
2080     /* Grab the next descriptor number they're advertising, and increment
2081      * the index we've seen. */
2082     *head = vring_avail_ring(vq, idx % vq->vring.num);
2083
2084     /* If their number is silly, that's a fatal mistake. */
2085     if (*head >= vq->vring.num) {
2086         vu_panic(dev, "Guest says index %u is available", *head);
2087         return false;
2088     }
2089
2090     return true;
2091 }
2092
2093 static int
2094 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2095                              uint64_t addr, size_t len)
2096 {
2097     struct vring_desc *ori_desc;
2098     uint64_t read_len;
2099
2100     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2101         return -1;
2102     }
2103
2104     if (len == 0) {
2105         return -1;
2106     }
2107
2108     while (len) {
2109         read_len = len;
2110         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2111         if (!ori_desc) {
2112             return -1;
2113         }
2114
2115         memcpy(desc, ori_desc, read_len);
2116         len -= read_len;
2117         addr += read_len;
2118         desc += read_len;
2119     }
2120
2121     return 0;
2122 }
2123
2124 enum {
2125     VIRTQUEUE_READ_DESC_ERROR = -1,
2126     VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
2127     VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
2128 };
2129
2130 static int
2131 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2132                          int i, unsigned int max, unsigned int *next)
2133 {
2134     /* If this descriptor says it doesn't chain, we're done. */
2135     if (!(lduw_le_p(&desc[i].flags) & VRING_DESC_F_NEXT)) {
2136         return VIRTQUEUE_READ_DESC_DONE;
2137     }
2138
2139     /* Check they're not leading us off end of descriptors. */
2140     *next = lduw_le_p(&desc[i].next);
2141     /* Make sure compiler knows to grab that: we don't want it changing! */
2142     smp_wmb();
2143
2144     if (*next >= max) {
2145         vu_panic(dev, "Desc next is %u", *next);
2146         return VIRTQUEUE_READ_DESC_ERROR;
2147     }
2148
2149     return VIRTQUEUE_READ_DESC_MORE;
2150 }
2151
2152 void
2153 vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
2154                          unsigned int *out_bytes,
2155                          unsigned max_in_bytes, unsigned max_out_bytes)
2156 {
2157     unsigned int idx;
2158     unsigned int total_bufs, in_total, out_total;
2159     int rc;
2160
2161     idx = vq->last_avail_idx;
2162
2163     total_bufs = in_total = out_total = 0;
2164     if (unlikely(dev->broken) ||
2165         unlikely(!vq->vring.avail)) {
2166         goto done;
2167     }
2168
2169     while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
2170         unsigned int max, desc_len, num_bufs, indirect = 0;
2171         uint64_t desc_addr, read_len;
2172         struct vring_desc *desc;
2173         struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2174         unsigned int i;
2175
2176         max = vq->vring.num;
2177         num_bufs = total_bufs;
2178         if (!virtqueue_get_head(dev, vq, idx++, &i)) {
2179             goto err;
2180         }
2181         desc = vq->vring.desc;
2182
2183         if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_INDIRECT) {
2184             if (ldl_le_p(&desc[i].len) % sizeof(struct vring_desc)) {
2185                 vu_panic(dev, "Invalid size for indirect buffer table");
2186                 goto err;
2187             }
2188
2189             /* If we've got too many, that implies a descriptor loop. */
2190             if (num_bufs >= max) {
2191                 vu_panic(dev, "Looped descriptor");
2192                 goto err;
2193             }
2194
2195             /* loop over the indirect descriptor table */
2196             indirect = 1;
2197             desc_addr = ldq_le_p(&desc[i].addr);
2198             desc_len = ldl_le_p(&desc[i].len);
2199             max = desc_len / sizeof(struct vring_desc);
2200             read_len = desc_len;
2201             desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2202             if (unlikely(desc && read_len != desc_len)) {
2203                 /* Failed to use zero copy */
2204                 desc = NULL;
2205                 if (!virtqueue_read_indirect_desc(dev, desc_buf,
2206                                                   desc_addr,
2207                                                   desc_len)) {
2208                     desc = desc_buf;
2209                 }
2210             }
2211             if (!desc) {
2212                 vu_panic(dev, "Invalid indirect buffer table");
2213                 goto err;
2214             }
2215             num_bufs = i = 0;
2216         }
2217
2218         do {
2219             /* If we've got too many, that implies a descriptor loop. */
2220             if (++num_bufs > max) {
2221                 vu_panic(dev, "Looped descriptor");
2222                 goto err;
2223             }
2224
2225             if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_WRITE) {
2226                 in_total += ldl_le_p(&desc[i].len);
2227             } else {
2228                 out_total += ldl_le_p(&desc[i].len);
2229             }
2230             if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
2231                 goto done;
2232             }
2233             rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2234         } while (rc == VIRTQUEUE_READ_DESC_MORE);
2235
2236         if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2237             goto err;
2238         }
2239
2240         if (!indirect) {
2241             total_bufs = num_bufs;
2242         } else {
2243             total_bufs++;
2244         }
2245     }
2246     if (rc < 0) {
2247         goto err;
2248     }
2249 done:
2250     if (in_bytes) {
2251         *in_bytes = in_total;
2252     }
2253     if (out_bytes) {
2254         *out_bytes = out_total;
2255     }
2256     return;
2257
2258 err:
2259     in_total = out_total = 0;
2260     goto done;
2261 }
2262
2263 bool
2264 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2265                      unsigned int out_bytes)
2266 {
2267     unsigned int in_total, out_total;
2268
2269     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2270                              in_bytes, out_bytes);
2271
2272     return in_bytes <= in_total && out_bytes <= out_total;
2273 }
2274
2275 /* Fetch avail_idx from VQ memory only when we really need to know if
2276  * guest has added some buffers. */
2277 bool
2278 vu_queue_empty(VuDev *dev, VuVirtq *vq)
2279 {
2280     if (unlikely(dev->broken) ||
2281         unlikely(!vq->vring.avail)) {
2282         return true;
2283     }
2284
2285     if (vq->shadow_avail_idx != vq->last_avail_idx) {
2286         return false;
2287     }
2288
2289     return vring_avail_idx(vq) == vq->last_avail_idx;
2290 }
2291
2292 static bool
2293 vring_notify(VuDev *dev, VuVirtq *vq)
2294 {
2295     uint16_t old, new;
2296     bool v;
2297
2298     /* We need to expose used array entries before checking used event. */
2299     smp_mb();
2300
2301     /* Always notify when queue is empty (when feature acknowledge) */
2302     if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2303         !vq->inuse && vu_queue_empty(dev, vq)) {
2304         return true;
2305     }
2306
2307     if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2308         return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
2309     }
2310
2311     v = vq->signalled_used_valid;
2312     vq->signalled_used_valid = true;
2313     old = vq->signalled_used;
2314     new = vq->signalled_used = vq->used_idx;
2315     return !v || vring_need_event(vring_get_used_event(vq), new, old);
2316 }
2317
2318 static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2319 {
2320     if (unlikely(dev->broken) ||
2321         unlikely(!vq->vring.avail)) {
2322         return;
2323     }
2324
2325     if (!vring_notify(dev, vq)) {
2326         DPRINT("skipped notify...\n");
2327         return;
2328     }
2329
2330     if (vq->call_fd < 0 &&
2331         vu_has_protocol_feature(dev,
2332                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2333         vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
2334         VhostUserMsg vmsg = {
2335             .request = VHOST_USER_SLAVE_VRING_CALL,
2336             .flags = VHOST_USER_VERSION,
2337             .size = sizeof(vmsg.payload.state),
2338             .payload.state = {
2339                 .index = vq - dev->vq,
2340             },
2341         };
2342         bool ack = sync &&
2343                    vu_has_protocol_feature(dev,
2344                                            VHOST_USER_PROTOCOL_F_REPLY_ACK);
2345
2346         if (ack) {
2347             vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2348         }
2349
2350         vu_message_write(dev, dev->slave_fd, &vmsg);
2351         if (ack) {
2352             vu_message_read(dev, dev->slave_fd, &vmsg);
2353         }
2354         return;
2355     }
2356
2357     if (eventfd_write(vq->call_fd, 1) < 0) {
2358         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2359     }
2360 }
2361
2362 void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2363 {
2364     _vu_queue_notify(dev, vq, false);
2365 }
2366
2367 void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2368 {
2369     _vu_queue_notify(dev, vq, true);
2370 }
2371
2372 static inline void
2373 vring_used_flags_set_bit(VuVirtq *vq, int mask)
2374 {
2375     uint16_t *flags;
2376
2377     flags = (uint16_t *)((char*)vq->vring.used +
2378                          offsetof(struct vring_used, flags));
2379     stw_le_p(flags, lduw_le_p(flags) | mask);
2380 }
2381
2382 static inline void
2383 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2384 {
2385     uint16_t *flags;
2386
2387     flags = (uint16_t *)((char*)vq->vring.used +
2388                          offsetof(struct vring_used, flags));
2389     stw_le_p(flags, lduw_le_p(flags) & ~mask);
2390 }
2391
2392 static inline void
2393 vring_set_avail_event(VuVirtq *vq, uint16_t val)
2394 {
2395     if (!vq->notification) {
2396         return;
2397     }
2398
2399     stw_le_p(&vq->vring.used->ring[vq->vring.num], val);
2400 }
2401
2402 void
2403 vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2404 {
2405     vq->notification = enable;
2406     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2407         vring_set_avail_event(vq, vring_avail_idx(vq));
2408     } else if (enable) {
2409         vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2410     } else {
2411         vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2412     }
2413     if (enable) {
2414         /* Expose avail event/used flags before caller checks the avail idx. */
2415         smp_mb();
2416     }
2417 }
2418
2419 static bool
2420 virtqueue_map_desc(VuDev *dev,
2421                    unsigned int *p_num_sg, struct iovec *iov,
2422                    unsigned int max_num_sg, bool is_write,
2423                    uint64_t pa, size_t sz)
2424 {
2425     unsigned num_sg = *p_num_sg;
2426
2427     assert(num_sg <= max_num_sg);
2428
2429     if (!sz) {
2430         vu_panic(dev, "virtio: zero sized buffers are not allowed");
2431         return false;
2432     }
2433
2434     while (sz) {
2435         uint64_t len = sz;
2436
2437         if (num_sg == max_num_sg) {
2438             vu_panic(dev, "virtio: too many descriptors in indirect table");
2439             return false;
2440         }
2441
2442         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2443         if (iov[num_sg].iov_base == NULL) {
2444             vu_panic(dev, "virtio: invalid address for buffers");
2445             return false;
2446         }
2447         iov[num_sg].iov_len = len;
2448         num_sg++;
2449         sz -= len;
2450         pa += len;
2451     }
2452
2453     *p_num_sg = num_sg;
2454     return true;
2455 }
2456
2457 static void *
2458 virtqueue_alloc_element(size_t sz,
2459                                      unsigned out_num, unsigned in_num)
2460 {
2461     VuVirtqElement *elem;
2462     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2463     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2464     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2465
2466     assert(sz >= sizeof(VuVirtqElement));
2467     elem = malloc(out_sg_end);
2468     elem->out_num = out_num;
2469     elem->in_num = in_num;
2470     elem->in_sg = (void *)elem + in_sg_ofs;
2471     elem->out_sg = (void *)elem + out_sg_ofs;
2472     return elem;
2473 }
2474
2475 static void *
2476 vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2477 {
2478     struct vring_desc *desc = vq->vring.desc;
2479     uint64_t desc_addr, read_len;
2480     unsigned int desc_len;
2481     unsigned int max = vq->vring.num;
2482     unsigned int i = idx;
2483     VuVirtqElement *elem;
2484     unsigned int out_num = 0, in_num = 0;
2485     struct iovec iov[VIRTQUEUE_MAX_SIZE];
2486     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2487     int rc;
2488
2489     if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_INDIRECT) {
2490         if (ldl_le_p(&desc[i].len) % sizeof(struct vring_desc)) {
2491             vu_panic(dev, "Invalid size for indirect buffer table");
2492             return NULL;
2493         }
2494
2495         /* loop over the indirect descriptor table */
2496         desc_addr = ldq_le_p(&desc[i].addr);
2497         desc_len = ldl_le_p(&desc[i].len);
2498         max = desc_len / sizeof(struct vring_desc);
2499         read_len = desc_len;
2500         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2501         if (unlikely(desc && read_len != desc_len)) {
2502             /* Failed to use zero copy */
2503             desc = NULL;
2504             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2505                                               desc_addr,
2506                                               desc_len)) {
2507                 desc = desc_buf;
2508             }
2509         }
2510         if (!desc) {
2511             vu_panic(dev, "Invalid indirect buffer table");
2512             return NULL;
2513         }
2514         i = 0;
2515     }
2516
2517     /* Collect all the descriptors */
2518     do {
2519         if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_WRITE) {
2520             if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2521                                VIRTQUEUE_MAX_SIZE - out_num, true,
2522                                ldq_le_p(&desc[i].addr),
2523                                ldl_le_p(&desc[i].len))) {
2524                 return NULL;
2525             }
2526         } else {
2527             if (in_num) {
2528                 vu_panic(dev, "Incorrect order for descriptors");
2529                 return NULL;
2530             }
2531             if (!virtqueue_map_desc(dev, &out_num, iov,
2532                                VIRTQUEUE_MAX_SIZE, false,
2533                                ldq_le_p(&desc[i].addr),
2534                                ldl_le_p(&desc[i].len))) {
2535                 return NULL;
2536             }
2537         }
2538
2539         /* If we've got too many, that implies a descriptor loop. */
2540         if ((in_num + out_num) > max) {
2541             vu_panic(dev, "Looped descriptor");
2542             return NULL;
2543         }
2544         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2545     } while (rc == VIRTQUEUE_READ_DESC_MORE);
2546
2547     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2548         vu_panic(dev, "read descriptor error");
2549         return NULL;
2550     }
2551
2552     /* Now copy what we have collected and mapped */
2553     elem = virtqueue_alloc_element(sz, out_num, in_num);
2554     elem->index = idx;
2555     for (i = 0; i < out_num; i++) {
2556         elem->out_sg[i] = iov[i];
2557     }
2558     for (i = 0; i < in_num; i++) {
2559         elem->in_sg[i] = iov[out_num + i];
2560     }
2561
2562     return elem;
2563 }
2564
2565 static int
2566 vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2567 {
2568     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2569         return 0;
2570     }
2571
2572     if (unlikely(!vq->inflight)) {
2573         return -1;
2574     }
2575
2576     vq->inflight->desc[desc_idx].counter = vq->counter++;
2577     vq->inflight->desc[desc_idx].inflight = 1;
2578
2579     return 0;
2580 }
2581
2582 static int
2583 vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2584 {
2585     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2586         return 0;
2587     }
2588
2589     if (unlikely(!vq->inflight)) {
2590         return -1;
2591     }
2592
2593     vq->inflight->last_batch_head = desc_idx;
2594
2595     return 0;
2596 }
2597
2598 static int
2599 vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2600 {
2601     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2602         return 0;
2603     }
2604
2605     if (unlikely(!vq->inflight)) {
2606         return -1;
2607     }
2608
2609     barrier();
2610
2611     vq->inflight->desc[desc_idx].inflight = 0;
2612
2613     barrier();
2614
2615     vq->inflight->used_idx = vq->used_idx;
2616
2617     return 0;
2618 }
2619
2620 void *
2621 vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
2622 {
2623     int i;
2624     unsigned int head;
2625     VuVirtqElement *elem;
2626
2627     if (unlikely(dev->broken) ||
2628         unlikely(!vq->vring.avail)) {
2629         return NULL;
2630     }
2631
2632     if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
2633         i = (--vq->resubmit_num);
2634         elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
2635
2636         if (!vq->resubmit_num) {
2637             free(vq->resubmit_list);
2638             vq->resubmit_list = NULL;
2639         }
2640
2641         return elem;
2642     }
2643
2644     if (vu_queue_empty(dev, vq)) {
2645         return NULL;
2646     }
2647     /*
2648      * Needed after virtio_queue_empty(), see comment in
2649      * virtqueue_num_heads().
2650      */
2651     smp_rmb();
2652
2653     if (vq->inuse >= vq->vring.num) {
2654         vu_panic(dev, "Virtqueue size exceeded");
2655         return NULL;
2656     }
2657
2658     if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
2659         return NULL;
2660     }
2661
2662     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2663         vring_set_avail_event(vq, vq->last_avail_idx);
2664     }
2665
2666     elem = vu_queue_map_desc(dev, vq, head, sz);
2667
2668     if (!elem) {
2669         return NULL;
2670     }
2671
2672     vq->inuse++;
2673
2674     vu_queue_inflight_get(dev, vq, head);
2675
2676     return elem;
2677 }
2678
2679 static void
2680 vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2681                         size_t len)
2682 {
2683     vq->inuse--;
2684     /* unmap, when DMA support is added */
2685 }
2686
2687 void
2688 vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2689                size_t len)
2690 {
2691     vq->last_avail_idx--;
2692     vu_queue_detach_element(dev, vq, elem, len);
2693 }
2694
2695 bool
2696 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2697 {
2698     if (num > vq->inuse) {
2699         return false;
2700     }
2701     vq->last_avail_idx -= num;
2702     vq->inuse -= num;
2703     return true;
2704 }
2705
2706 static inline
2707 void vring_used_write(VuDev *dev, VuVirtq *vq,
2708                       struct vring_used_elem *uelem, int i)
2709 {
2710     struct vring_used *used = vq->vring.used;
2711
2712     used->ring[i] = *uelem;
2713     vu_log_write(dev, vq->vring.log_guest_addr +
2714                  offsetof(struct vring_used, ring[i]),
2715                  sizeof(used->ring[i]));
2716 }
2717
2718
2719 static void
2720 vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
2721                   const VuVirtqElement *elem,
2722                   unsigned int len)
2723 {
2724     struct vring_desc *desc = vq->vring.desc;
2725     unsigned int i, max, min, desc_len;
2726     uint64_t desc_addr, read_len;
2727     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2728     unsigned num_bufs = 0;
2729
2730     max = vq->vring.num;
2731     i = elem->index;
2732
2733     if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_INDIRECT) {
2734         if (ldl_le_p(&desc[i].len) % sizeof(struct vring_desc)) {
2735             vu_panic(dev, "Invalid size for indirect buffer table");
2736             return;
2737         }
2738
2739         /* loop over the indirect descriptor table */
2740         desc_addr = ldq_le_p(&desc[i].addr);
2741         desc_len = ldl_le_p(&desc[i].len);
2742         max = desc_len / sizeof(struct vring_desc);
2743         read_len = desc_len;
2744         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2745         if (unlikely(desc && read_len != desc_len)) {
2746             /* Failed to use zero copy */
2747             desc = NULL;
2748             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2749                                               desc_addr,
2750                                               desc_len)) {
2751                 desc = desc_buf;
2752             }
2753         }
2754         if (!desc) {
2755             vu_panic(dev, "Invalid indirect buffer table");
2756             return;
2757         }
2758         i = 0;
2759     }
2760
2761     do {
2762         if (++num_bufs > max) {
2763             vu_panic(dev, "Looped descriptor");
2764             return;
2765         }
2766
2767         if (lduw_le_p(&desc[i].flags) & VRING_DESC_F_WRITE) {
2768             min = MIN(ldl_le_p(&desc[i].len), len);
2769             vu_log_write(dev, ldq_le_p(&desc[i].addr), min);
2770             len -= min;
2771         }
2772
2773     } while (len > 0 &&
2774              (virtqueue_read_next_desc(dev, desc, i, max, &i)
2775               == VIRTQUEUE_READ_DESC_MORE));
2776 }
2777
2778 void
2779 vu_queue_fill(VuDev *dev, VuVirtq *vq,
2780               const VuVirtqElement *elem,
2781               unsigned int len, unsigned int idx)
2782 {
2783     struct vring_used_elem uelem;
2784
2785     if (unlikely(dev->broken) ||
2786         unlikely(!vq->vring.avail)) {
2787         return;
2788     }
2789
2790     vu_log_queue_fill(dev, vq, elem, len);
2791
2792     idx = (idx + vq->used_idx) % vq->vring.num;
2793
2794     stl_le_p(&uelem.id, elem->index);
2795     stl_le_p(&uelem.len, len);
2796     vring_used_write(dev, vq, &uelem, idx);
2797 }
2798
2799 static inline
2800 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2801 {
2802     stw_le_p(&vq->vring.used->idx, val);
2803     vu_log_write(dev,
2804                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2805                  sizeof(vq->vring.used->idx));
2806
2807     vq->used_idx = val;
2808 }
2809
2810 void
2811 vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2812 {
2813     uint16_t old, new;
2814
2815     if (unlikely(dev->broken) ||
2816         unlikely(!vq->vring.avail)) {
2817         return;
2818     }
2819
2820     /* Make sure buffer is written before we update index. */
2821     smp_wmb();
2822
2823     old = vq->used_idx;
2824     new = old + count;
2825     vring_used_idx_set(dev, vq, new);
2826     vq->inuse -= count;
2827     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2828         vq->signalled_used_valid = false;
2829     }
2830 }
2831
2832 void
2833 vu_queue_push(VuDev *dev, VuVirtq *vq,
2834               const VuVirtqElement *elem, unsigned int len)
2835 {
2836     vu_queue_fill(dev, vq, elem, len, 0);
2837     vu_queue_inflight_pre_put(dev, vq, elem->index);
2838     vu_queue_flush(dev, vq, 1);
2839     vu_queue_inflight_post_put(dev, vq, elem->index);
2840 }