OSDN Git Service

soc: qcom: hab: hold the message when the read buffer is smaller
authorYong Ding <yongding@codeaurora.org>
Mon, 23 Apr 2018 11:04:39 +0000 (19:04 +0800)
committerYong Ding <yongding@codeaurora.org>
Wed, 9 May 2018 11:31:48 +0000 (19:31 +0800)
If the receive buffer in habmm_socket_recv() is smaller
than the message, it should be kept in the queue rather
than dropped.

Change-Id: Iabc1f73e5b443cc7ebdefc3961d5bec9049a265f
Signed-off-by: Yong Ding <yongding@codeaurora.org>
drivers/soc/qcom/hab/hab.c
drivers/soc/qcom/hab/hab.h
drivers/soc/qcom/hab/hab_msg.c
drivers/soc/qcom/hab/khab.c

index 3294fc3..37afe02 100644 (file)
@@ -356,18 +356,21 @@ err:
        return ret;
 }
 
-struct hab_message *hab_vchan_recv(struct uhab_context *ctx,
+int hab_vchan_recv(struct uhab_context *ctx,
+                               struct hab_message **message,
                                int vcid,
+                               int *rsize,
                                unsigned int flags)
 {
        struct virtual_channel *vchan;
-       struct hab_message *message;
        int ret = 0;
        int nonblocking_flag = flags & HABMM_SOCKET_RECV_FLAGS_NON_BLOCKING;
 
        vchan = hab_get_vchan_fromvcid(vcid, ctx);
-       if (!vchan)
-               return ERR_PTR(-ENODEV);
+       if (!vchan) {
+               pr_err("vcid %X, vchan %p ctx %p\n", vcid, vchan, ctx);
+               return -ENODEV;
+       }
 
        if (nonblocking_flag) {
                /*
@@ -378,18 +381,18 @@ struct hab_message *hab_vchan_recv(struct uhab_context *ctx,
                physical_channel_rx_dispatch((unsigned long) vchan->pchan);
        }
 
-       message = hab_msg_dequeue(vchan, flags);
-       if (!message) {
+       ret = hab_msg_dequeue(vchan, message, rsize, flags);
+       if (!(*message)) {
                if (nonblocking_flag)
                        ret = -EAGAIN;
                else if (vchan->otherend_closed)
                        ret = -ENODEV;
-               else
-                       ret = -EPIPE;
+               else if (ret == -ERESTARTSYS)
+                       ret = -EINTR;
        }
 
        hab_vchan_put(vchan);
-       return ret ? ERR_PTR(ret) : message;
+       return ret;
 }
 
 bool hab_is_loopback(void)
@@ -843,29 +846,22 @@ static long hab_ioctl(struct file *filep, unsigned int cmd, unsigned long arg)
                        break;
                }
 
-               msg = hab_vchan_recv(ctx, recv_param->vcid, recv_param->flags);
-
-               if (IS_ERR(msg)) {
-                       recv_param->sizebytes = 0;
-                       ret = PTR_ERR(msg);
-                       break;
-               }
+               ret = hab_vchan_recv(ctx, &msg, recv_param->vcid,
+                               &recv_param->sizebytes, recv_param->flags);
 
-               if (recv_param->sizebytes < msg->sizebytes) {
-                       recv_param->sizebytes = 0;
-                       ret = -EINVAL;
-               } else if (copy_to_user((void __user *)recv_param->data,
+               if (ret == 0 && msg) {
+                       if (copy_to_user((void __user *)recv_param->data,
                                        msg->data,
                                        msg->sizebytes)) {
-                       pr_err("copy_to_user failed: vc=%x size=%d\n",
-                               recv_param->vcid, (int)msg->sizebytes);
-                       recv_param->sizebytes = 0;
-                       ret = -EFAULT;
-               } else {
-                       recv_param->sizebytes = msg->sizebytes;
+                               pr_err("copy_to_user failed: vc=%x size=%d\n",
+                                  recv_param->vcid, (int)msg->sizebytes);
+                               recv_param->sizebytes = 0;
+                               ret = -EFAULT;
+                       }
                }
 
-               hab_msg_free(msg);
+               if (msg)
+                       hab_msg_free(msg);
                break;
        case IOCTL_HAB_VC_EXPORT:
                ret = hab_mem_export(ctx, (struct hab_export *)data, 0);
index ffb0637..2a07da7 100644 (file)
@@ -373,9 +373,11 @@ long hab_vchan_send(struct uhab_context *ctx,
                size_t sizebytes,
                void *data,
                unsigned int flags);
-struct hab_message *hab_vchan_recv(struct uhab_context *ctx,
-                               int vcid,
-                               unsigned int flags);
+int hab_vchan_recv(struct uhab_context *ctx,
+               struct hab_message **msg,
+               int vcid,
+               int *rsize,
+               unsigned int flags);
 void hab_vchan_stop(struct virtual_channel *vchan);
 void hab_vchans_stop(struct physical_channel *pchan);
 void hab_vchan_stop_notify(struct virtual_channel *vchan);
@@ -422,8 +424,8 @@ int habmem_imp_hyp_mmap(struct file *flip, struct vm_area_struct *vma);
 
 
 void hab_msg_free(struct hab_message *message);
-struct hab_message *hab_msg_dequeue(struct virtual_channel *vchan,
-               unsigned int flags);
+int hab_msg_dequeue(struct virtual_channel *vchan,
+               struct hab_message **msg, int *rsize, unsigned int flags);
 
 void hab_msg_recv(struct physical_channel *pchan,
                struct hab_header *header);
index d5c625e..d904cde 100644 (file)
@@ -42,8 +42,9 @@ void hab_msg_free(struct hab_message *message)
        kfree(message);
 }
 
-struct hab_message *
-hab_msg_dequeue(struct virtual_channel *vchan, unsigned int flags)
+int
+hab_msg_dequeue(struct virtual_channel *vchan, struct hab_message **msg,
+               int *rsize, unsigned int flags)
 {
        struct hab_message *message = NULL;
        int ret = 0;
@@ -64,15 +65,30 @@ hab_msg_dequeue(struct virtual_channel *vchan, unsigned int flags)
        }
 
        /* return all the received messages before the remote close */
-       if (!ret && !hab_rx_queue_empty(vchan)) {
+       if ((!ret || (ret == -ERESTARTSYS)) && !hab_rx_queue_empty(vchan)) {
                spin_lock_bh(&vchan->rx_lock);
                message = list_first_entry(&vchan->rx_list,
                                struct hab_message, node);
-               list_del(&message->node);
+               if (message) {
+                       if (*rsize >= message->sizebytes) {
+                               /* msg can be safely retrieved in full */
+                               list_del(&message->node);
+                               ret = 0;
+                               *rsize = message->sizebytes;
+                       } else {
+                               pr_err("rcv buffer too small %d < %zd\n",
+                                          *rsize, message->sizebytes);
+                               *rsize = 0;
+                               message = NULL;
+                               ret = -EINVAL;
+                       }
+               }
                spin_unlock_bh(&vchan->rx_lock);
-       }
+       } else
+               *rsize = 0;
 
-       return message;
+       *msg = message;
+       return ret;
 }
 
 static void hab_msg_queue(struct virtual_channel *vchan,
index 3fdd11f..ba77e5e 100644 (file)
@@ -51,22 +51,14 @@ int32_t habmm_socket_recv(int32_t handle, void *dst_buff, uint32_t *size_bytes,
        if (!size_bytes || !dst_buff)
                return -EINVAL;
 
-       msg = hab_vchan_recv(hab_driver.kctx, handle, flags);
+       ret = hab_vchan_recv(hab_driver.kctx, &msg, handle, size_bytes, flags);
 
-       if (IS_ERR(msg)) {
-               *size_bytes = 0;
-               return PTR_ERR(msg);
-       }
-
-       if (*size_bytes < msg->sizebytes) {
-               *size_bytes = 0;
-               ret = -EINVAL;
-       } else {
+       if (ret == 0 && msg)
                memcpy(dst_buff, msg->data, msg->sizebytes);
-               *size_bytes = msg->sizebytes;
-       }
 
-       hab_msg_free(msg);
+       if (msg)
+               hab_msg_free(msg);
+
        return ret;
 }
 EXPORT_SYMBOL(habmm_socket_recv);