OSDN Git Service

habanalabs: change wait_for_interrupt implementation
authorfarah kassabri <fkassabri@habana.ai>
Tue, 2 Nov 2021 09:34:18 +0000 (11:34 +0200)
committerOded Gabbay <ogabbay@kernel.org>
Sun, 26 Dec 2021 06:59:09 +0000 (08:59 +0200)
Currently the cq counters are allocated in userspace memory,
and mapped by the driver to the device address space.

A new requirement that is part of new future API related to this one,
requires that cq counters will be allocated in kernel memory.

We leverage the existing cb_create API with KERNEL_MAPPED flag set to
allocate this memory.

That way we gain two things:
1. The memory cannot be freed while in use since it's protected
by refcount in driver.

2. No need to wake up the user thread upon each interrupt from CQ,
because the kernel has direct access to the counter. Therefore,
it can make comparison with the target value in the interrupt
handler and wake up the user thread only if the counter reaches the
target value. This is instead of waking the thread up to copy counter
value from user then go sleep again if target value wasn't reached.

Signed-off-by: farah kassabri <fkassabri@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
drivers/misc/habanalabs/common/command_buffer.c
drivers/misc/habanalabs/common/command_submission.c
drivers/misc/habanalabs/common/habanalabs.h
drivers/misc/habanalabs/common/irq.c
include/uapi/misc/habanalabs.h

index c591f04..d4eb9fb 100644 (file)
@@ -380,8 +380,9 @@ int hl_cb_destroy(struct hl_device *hdev, struct hl_cb_mgr *mgr, u64 cb_handle)
 }
 
 static int hl_cb_info(struct hl_device *hdev, struct hl_cb_mgr *mgr,
-                       u64 cb_handle, u32 *usage_cnt)
+                       u64 cb_handle, u32 flags, u32 *usage_cnt, u64 *device_va)
 {
+       struct hl_vm_va_block *va_block;
        struct hl_cb *cb;
        u32 handle;
        int rc = 0;
@@ -402,7 +403,18 @@ static int hl_cb_info(struct hl_device *hdev, struct hl_cb_mgr *mgr,
                goto out;
        }
 
-       *usage_cnt = atomic_read(&cb->cs_cnt);
+       if (flags & HL_CB_FLAGS_GET_DEVICE_VA) {
+               va_block = list_first_entry(&cb->va_block_list, struct hl_vm_va_block, node);
+               if (va_block) {
+                       *device_va = va_block->start;
+               } else {
+                       dev_err(hdev->dev, "CB is not mapped to the device's MMU\n");
+                       rc = -EINVAL;
+                       goto out;
+               }
+       } else {
+               *usage_cnt = atomic_read(&cb->cs_cnt);
+       }
 
 out:
        spin_unlock(&mgr->cb_lock);
@@ -414,7 +426,7 @@ int hl_cb_ioctl(struct hl_fpriv *hpriv, void *data)
        union hl_cb_args *args = data;
        struct hl_device *hdev = hpriv->hdev;
        enum hl_device_status status;
-       u64 handle = 0;
+       u64 handle = 0, device_va;
        u32 usage_cnt = 0;
        int rc;
 
@@ -450,9 +462,16 @@ int hl_cb_ioctl(struct hl_fpriv *hpriv, void *data)
 
        case HL_CB_OP_INFO:
                rc = hl_cb_info(hdev, &hpriv->cb_mgr, args->in.cb_handle,
-                               &usage_cnt);
-               memset(args, 0, sizeof(*args));
-               args->out.usage_cnt = usage_cnt;
+                               args->in.flags,
+                               &usage_cnt,
+                               &device_va);
+
+               memset(&args->out, 0, sizeof(args->out));
+
+               if (args->in.flags & HL_CB_FLAGS_GET_DEVICE_VA)
+                       args->out.device_va = device_va;
+               else
+                       args->out.usage_cnt = usage_cnt;
                break;
 
        default:
index b9fed6b..7073fa6 100644 (file)
@@ -2845,6 +2845,106 @@ static int hl_cs_wait_ioctl(struct hl_fpriv *hpriv, void *data)
 }
 
 static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
+                               struct hl_cb_mgr *cb_mgr, u64 timeout_us,
+                               u64 cq_counters_handle, u64 cq_counters_offset,
+                               u64 target_value, struct hl_user_interrupt *interrupt,
+                               u32 *status,
+                               u64 *timestamp)
+{
+       struct hl_user_pending_interrupt *pend;
+       unsigned long timeout, flags;
+       long completion_rc;
+       struct hl_cb *cb;
+       int rc = 0;
+       u32 handle;
+
+       timeout = hl_usecs64_to_jiffies(timeout_us);
+
+       hl_ctx_get(hdev, ctx);
+
+       cq_counters_handle >>= PAGE_SHIFT;
+       handle = (u32) cq_counters_handle;
+
+       cb = hl_cb_get(hdev, cb_mgr, handle);
+       if (!cb) {
+               hl_ctx_put(ctx);
+               return -EINVAL;
+       }
+
+       pend = kzalloc(sizeof(*pend), GFP_KERNEL);
+       if (!pend) {
+               hl_cb_put(cb);
+               hl_ctx_put(ctx);
+               return -ENOMEM;
+       }
+
+       hl_fence_init(&pend->fence, ULONG_MAX);
+
+       pend->cq_kernel_addr = (u64 *) cb->kernel_address + cq_counters_offset;
+       pend->cq_target_value = target_value;
+
+       /* We check for completion value as interrupt could have been received
+        * before we added the node to the wait list
+        */
+       if (*pend->cq_kernel_addr >= target_value) {
+               *status = HL_WAIT_CS_STATUS_COMPLETED;
+               /* There was no interrupt, we assume the completion is now. */
+               pend->fence.timestamp = ktime_get();
+       }
+
+       if (!timeout_us || (*status == HL_WAIT_CS_STATUS_COMPLETED))
+               goto set_timestamp;
+
+       /* Add pending user interrupt to relevant list for the interrupt
+        * handler to monitor
+        */
+       spin_lock_irqsave(&interrupt->wait_list_lock, flags);
+       list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
+       spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
+
+       /* Wait for interrupt handler to signal completion */
+       completion_rc = wait_for_completion_interruptible_timeout(&pend->fence.completion,
+                                                               timeout);
+       if (completion_rc > 0) {
+               *status = HL_WAIT_CS_STATUS_COMPLETED;
+       } else {
+               if (completion_rc == -ERESTARTSYS) {
+                       dev_err_ratelimited(hdev->dev,
+                                       "user process got signal while waiting for interrupt ID %d\n",
+                                       interrupt->interrupt_id);
+                       rc = -EINTR;
+                       *status = HL_WAIT_CS_STATUS_ABORTED;
+               } else {
+                       if (pend->fence.error == -EIO) {
+                               dev_err_ratelimited(hdev->dev,
+                                               "interrupt based wait ioctl aborted(error:%d) due to a reset cycle initiated\n",
+                                               pend->fence.error);
+                               rc = -EIO;
+                               *status = HL_WAIT_CS_STATUS_ABORTED;
+                       } else {
+                               dev_err_ratelimited(hdev->dev, "Waiting for interrupt ID %d timedout\n",
+                                               interrupt->interrupt_id);
+                               rc = -ETIMEDOUT;
+                       }
+                       *status = HL_WAIT_CS_STATUS_BUSY;
+               }
+       }
+
+       spin_lock_irqsave(&interrupt->wait_list_lock, flags);
+       list_del(&pend->wait_list_node);
+       spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
+
+set_timestamp:
+       *timestamp = ktime_to_ns(pend->fence.timestamp);
+
+       kfree(pend);
+       hl_cb_put(cb);
+       hl_ctx_put(ctx);
+
+       return rc;
+}
+
+static int _hl_interrupt_wait_ioctl_user_addr(struct hl_device *hdev, struct hl_ctx *ctx,
                                u64 timeout_us, u64 user_address,
                                u64 target_value, struct hl_user_interrupt *interrupt,
 
@@ -2861,7 +2961,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
 
        hl_ctx_get(hdev, ctx);
 
-       pend = kmalloc(sizeof(*pend), GFP_KERNEL);
+       pend = kzalloc(sizeof(*pend), GFP_KERNEL);
        if (!pend) {
                hl_ctx_put(ctx);
                return -ENOMEM;
@@ -2990,7 +3090,14 @@ static int hl_interrupt_wait_ioctl(struct hl_fpriv *hpriv, void *data)
        else
                interrupt = &hdev->user_interrupt[interrupt_id - first_interrupt];
 
-       rc = _hl_interrupt_wait_ioctl(hdev, hpriv->ctx,
+       if (args->in.flags & HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ)
+               rc = _hl_interrupt_wait_ioctl(hdev, hpriv->ctx, &hpriv->cb_mgr,
+                               args->in.interrupt_timeout_us, args->in.cq_counters_handle,
+                               args->in.cq_counters_offset,
+                               args->in.target, interrupt, &status,
+                               &timestamp);
+       else
+               rc = _hl_interrupt_wait_ioctl_user_addr(hdev, hpriv->ctx,
                                args->in.interrupt_timeout_us, args->in.addr,
                                args->in.target, interrupt, &status,
                                &timestamp);
index 4d49861..78772fe 100644 (file)
@@ -876,10 +876,15 @@ struct hl_user_interrupt {
  *                                    pending on an interrupt
  * @wait_list_node: node in the list of user threads pending on an interrupt
  * @fence: hl fence object for interrupt completion
+ * @cq_target_value: CQ target value
+ * @cq_kernel_addr: CQ kernel address, to be used in the cq interrupt
+ *                  handler for taget value comparison
  */
 struct hl_user_pending_interrupt {
        struct list_head        wait_list_node;
        struct hl_fence         fence;
+       u64                     cq_target_value;
+       u64                     *cq_kernel_addr;
 };
 
 /**
index 64e0d9d..6454ea1 100644 (file)
@@ -145,8 +145,12 @@ static void handle_user_cq(struct hl_device *hdev,
 
        spin_lock(&user_cq->wait_list_lock);
        list_for_each_entry(pend, &user_cq->wait_list_head, wait_list_node) {
-               pend->fence.timestamp = now;
-               complete_all(&pend->fence.completion);
+               if ((pend->cq_kernel_addr &&
+                               *(pend->cq_kernel_addr) >= pend->cq_target_value) ||
+                               !pend->cq_kernel_addr) {
+                       pend->fence.timestamp = now;
+                       complete_all(&pend->fence.completion);
+               }
        }
        spin_unlock(&user_cq->wait_list_lock);
 }
index 648850b..371dfc4 100644 (file)
@@ -680,7 +680,10 @@ struct hl_info_args {
 #define HL_MAX_CB_SIZE         (0x200000 - 32)
 
 /* Indicates whether the command buffer should be mapped to the device's MMU */
-#define HL_CB_FLAGS_MAP                0x1
+#define HL_CB_FLAGS_MAP                        0x1
+
+/* Used with HL_CB_OP_INFO opcode to get the device va address for kernel mapped CB */
+#define HL_CB_FLAGS_GET_DEVICE_VA      0x2
 
 struct hl_cb_in {
        /* Handle of CB or 0 if we want to create one */
@@ -702,11 +705,16 @@ struct hl_cb_out {
                /* Handle of CB */
                __u64 cb_handle;
 
-               /* Information about CB */
-               struct {
-                       /* Usage count of CB */
-                       __u32 usage_cnt;
-                       __u32 pad;
+               union {
+                       /* Information about CB */
+                       struct {
+                               /* Usage count of CB */
+                               __u32 usage_cnt;
+                               __u32 pad;
+                       };
+
+                       /* CB mapped address to device MMU */
+                       __u64 device_va;
                };
        };
 };
@@ -947,9 +955,10 @@ union hl_cs_args {
        struct hl_cs_out out;
 };
 
-#define HL_WAIT_CS_FLAGS_INTERRUPT     0x2
-#define HL_WAIT_CS_FLAGS_INTERRUPT_MASK 0xFFF00000
-#define HL_WAIT_CS_FLAGS_MULTI_CS      0x4
+#define HL_WAIT_CS_FLAGS_INTERRUPT             0x2
+#define HL_WAIT_CS_FLAGS_INTERRUPT_MASK                0xFFF00000
+#define HL_WAIT_CS_FLAGS_MULTI_CS              0x4
+#define HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ   0x10
 
 #define HL_WAIT_MULTI_CS_LIST_MAX_LEN  32
 
@@ -969,14 +978,23 @@ struct hl_wait_cs_in {
                };
 
                struct {
-                       /* User address for completion comparison.
-                        * upon interrupt, driver will compare the value pointed
-                        * by this address with the supplied target value.
-                        * in order not to perform any comparison, set address
-                        * to all 1s.
-                        * Relevant only when HL_WAIT_CS_FLAGS_INTERRUPT is set
-                        */
-                       __u64 addr;
+                       union {
+                               /* User address for completion comparison.
+                                * upon interrupt, driver will compare the value pointed
+                                * by this address with the supplied target value.
+                                * in order not to perform any comparison, set address
+                                * to all 1s.
+                                * Relevant only when HL_WAIT_CS_FLAGS_INTERRUPT is set
+                                */
+                               __u64 addr;
+
+                               /* cq_counters_handle to a kernel mapped cb which contains
+                                * cq counters.
+                                * Relevant only when HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ is set
+                                */
+                               __u64 cq_counters_handle;
+                       };
+
                        /* Target value for completion comparison */
                        __u64 target;
                };
@@ -1004,6 +1022,15 @@ struct hl_wait_cs_in {
                 */
                __u64 interrupt_timeout_us;
        };
+
+       /*
+        * cq counter offset inside the counters cb pointed by cq_counters_handle above.
+        * upon interrupt, driver will compare the value pointed
+        * by this address (cq_counters_handle + cq_counters_offset)
+        * with the supplied target value.
+        * relevant only when HL_WAIT_CS_FLAGS_INTERRUPT_KERNEL_CQ is set
+        */
+       __u64 cq_counters_offset;
 };
 
 #define HL_WAIT_CS_STATUS_COMPLETED    0