OSDN Git Service

Merge tag 'kvmarm-fixes-for-5.1' of git://git.kernel.org/pub/scm/linux/kernel/git...
[uclinux-h8/linux.git] / drivers / misc / habanalabs / mmu.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 /*
4  * Copyright 2016-2019 HabanaLabs, Ltd.
5  * All Rights Reserved.
6  */
7
8 #include "habanalabs.h"
9 #include "include/hw_ip/mmu/mmu_general.h"
10
11 #include <linux/genalloc.h>
12 #include <linux/slab.h>
13
14 static struct pgt_info *get_pgt_info(struct hl_ctx *ctx, u64 addr)
15 {
16         struct pgt_info *pgt_info = NULL;
17
18         hash_for_each_possible(ctx->mmu_hash, pgt_info, node,
19                                 (unsigned long) addr)
20                 if (addr == pgt_info->addr)
21                         break;
22
23         return pgt_info;
24 }
25
26 static void free_hop(struct hl_ctx *ctx, u64 hop_addr)
27 {
28         struct pgt_info *pgt_info = get_pgt_info(ctx, hop_addr);
29
30         gen_pool_free(pgt_info->ctx->hdev->mmu_pgt_pool, pgt_info->addr,
31                         ctx->hdev->asic_prop.mmu_hop_table_size);
32         hash_del(&pgt_info->node);
33
34         kfree(pgt_info);
35 }
36
37 static u64 alloc_hop(struct hl_ctx *ctx)
38 {
39         struct hl_device *hdev = ctx->hdev;
40         struct pgt_info *pgt_info;
41         u64 addr;
42
43         pgt_info = kmalloc(sizeof(*pgt_info), GFP_KERNEL);
44         if (!pgt_info)
45                 return ULLONG_MAX;
46
47         addr = (u64) gen_pool_alloc(hdev->mmu_pgt_pool,
48                         hdev->asic_prop.mmu_hop_table_size);
49         if (!addr) {
50                 dev_err(hdev->dev, "failed to allocate page\n");
51                 kfree(pgt_info);
52                 return ULLONG_MAX;
53         }
54
55         pgt_info->addr = addr;
56         pgt_info->ctx = ctx;
57         pgt_info->num_of_ptes = 0;
58         hash_add(ctx->mmu_hash, &pgt_info->node, addr);
59
60         return addr;
61 }
62
63 static inline void clear_pte(struct hl_device *hdev, u64 pte_addr)
64 {
65         /* clear the last and present bits */
66         hdev->asic_funcs->write_pte(hdev, pte_addr, 0);
67 }
68
69 static inline void get_pte(struct hl_ctx *ctx, u64 hop_addr)
70 {
71         get_pgt_info(ctx, hop_addr)->num_of_ptes++;
72 }
73
74 /*
75  * put_pte - decrement the num of ptes and free the hop if possible
76  *
77  * @ctx: pointer to the context structure
78  * @hop_addr: addr of the hop
79  *
80  * This function returns the number of ptes left on this hop. If the number is
81  * 0, it means the pte was freed.
82  */
83 static inline int put_pte(struct hl_ctx *ctx, u64 hop_addr)
84 {
85         struct pgt_info *pgt_info = get_pgt_info(ctx, hop_addr);
86         int num_of_ptes_left;
87
88         pgt_info->num_of_ptes--;
89
90         /*
91          * Need to save the number of ptes left because free_hop might free
92          * the pgt_info
93          */
94         num_of_ptes_left = pgt_info->num_of_ptes;
95         if (!num_of_ptes_left)
96                 free_hop(ctx, hop_addr);
97
98         return num_of_ptes_left;
99 }
100
101 static inline u64 get_hop0_addr(struct hl_ctx *ctx)
102 {
103         return ctx->hdev->asic_prop.mmu_pgt_addr +
104                         (ctx->asid * ctx->hdev->asic_prop.mmu_hop_table_size);
105 }
106
107 static inline u64 get_hopN_pte_addr(struct hl_ctx *ctx, u64 hop_addr,
108                                         u64 virt_addr, u64 mask, u64 shift)
109 {
110         return hop_addr + ctx->hdev->asic_prop.mmu_pte_size *
111                         ((virt_addr & mask) >> shift);
112 }
113
114 static inline u64 get_hop0_pte_addr(struct hl_ctx *ctx, u64 hop_addr, u64 vaddr)
115 {
116         return get_hopN_pte_addr(ctx, hop_addr, vaddr, HOP0_MASK, HOP0_SHIFT);
117 }
118
119 static inline u64 get_hop1_pte_addr(struct hl_ctx *ctx, u64 hop_addr, u64 vaddr)
120 {
121         return get_hopN_pte_addr(ctx, hop_addr, vaddr, HOP1_MASK, HOP1_SHIFT);
122 }
123
124 static inline u64 get_hop2_pte_addr(struct hl_ctx *ctx, u64 hop_addr, u64 vaddr)
125 {
126         return get_hopN_pte_addr(ctx, hop_addr, vaddr, HOP2_MASK, HOP2_SHIFT);
127 }
128
129 static inline u64 get_hop3_pte_addr(struct hl_ctx *ctx, u64 hop_addr, u64 vaddr)
130 {
131         return get_hopN_pte_addr(ctx, hop_addr, vaddr, HOP3_MASK, HOP3_SHIFT);
132 }
133
134 static inline u64 get_hop4_pte_addr(struct hl_ctx *ctx, u64 hop_addr, u64 vaddr)
135 {
136         return get_hopN_pte_addr(ctx, hop_addr, vaddr, HOP4_MASK, HOP4_SHIFT);
137 }
138
139 static inline u64 get_next_hop_addr(u64 curr_pte)
140 {
141         if (curr_pte & PAGE_PRESENT_MASK)
142                 return curr_pte & PHYS_ADDR_MASK;
143         else
144                 return ULLONG_MAX;
145 }
146
147 static inline u64 get_alloc_next_hop_addr(struct hl_ctx *ctx, u64 curr_pte,
148                                                 bool *is_new_hop)
149 {
150         u64 hop_addr = get_next_hop_addr(curr_pte);
151
152         if (hop_addr == ULLONG_MAX) {
153                 hop_addr = alloc_hop(ctx);
154                 *is_new_hop = (hop_addr != ULLONG_MAX);
155         }
156
157         return hop_addr;
158 }
159
160 /*
161  * hl_mmu_init - init the mmu module
162  *
163  * @hdev: pointer to the habanalabs device structure
164  *
165  * This function does the following:
166  * - Allocate max_asid zeroed hop0 pgts so no mapping is available
167  * - Enable mmu in hw
168  * - Invalidate the mmu cache
169  * - Create a pool of pages for pgts
170  * - Returns 0 on success
171  *
172  * This function depends on DMA QMAN to be working!
173  */
174 int hl_mmu_init(struct hl_device *hdev)
175 {
176         struct asic_fixed_properties *prop = &hdev->asic_prop;
177         int rc;
178
179         if (!hdev->mmu_enable)
180                 return 0;
181
182         /* MMU HW init was already done in device hw_init() */
183
184         mutex_init(&hdev->mmu_cache_lock);
185
186         hdev->mmu_pgt_pool =
187                         gen_pool_create(__ffs(prop->mmu_hop_table_size), -1);
188
189         if (!hdev->mmu_pgt_pool) {
190                 dev_err(hdev->dev, "Failed to create page gen pool\n");
191                 rc = -ENOMEM;
192                 goto err_pool_create;
193         }
194
195         rc = gen_pool_add(hdev->mmu_pgt_pool, prop->mmu_pgt_addr +
196                         prop->mmu_hop0_tables_total_size,
197                         prop->mmu_pgt_size - prop->mmu_hop0_tables_total_size,
198                         -1);
199         if (rc) {
200                 dev_err(hdev->dev, "Failed to add memory to page gen pool\n");
201                 goto err_pool_add;
202         }
203
204         return 0;
205
206 err_pool_add:
207         gen_pool_destroy(hdev->mmu_pgt_pool);
208 err_pool_create:
209         mutex_destroy(&hdev->mmu_cache_lock);
210
211         return rc;
212 }
213
214 /*
215  * hl_mmu_fini - release the mmu module.
216  *
217  * @hdev: pointer to the habanalabs device structure
218  *
219  * This function does the following:
220  * - Disable mmu in hw
221  * - free the pgts pool
222  *
223  * All ctxs should be freed before calling this func
224  */
225 void hl_mmu_fini(struct hl_device *hdev)
226 {
227         if (!hdev->mmu_enable)
228                 return;
229
230         gen_pool_destroy(hdev->mmu_pgt_pool);
231
232         mutex_destroy(&hdev->mmu_cache_lock);
233
234         /* MMU HW fini will be done in device hw_fini() */
235 }
236
237 /**
238  * hl_mmu_ctx_init() - initialize a context for using the MMU module.
239  * @ctx: pointer to the context structure to initialize.
240  *
241  * Initialize a mutex to protect the concurrent mapping flow, a hash to hold all
242  * page tables hops related to this context and an optional DRAM default page
243  * mapping.
244  * Return: 0 on success, non-zero otherwise.
245  */
246 int hl_mmu_ctx_init(struct hl_ctx *ctx)
247 {
248         struct hl_device *hdev = ctx->hdev;
249         struct asic_fixed_properties *prop = &hdev->asic_prop;
250         u64 num_of_hop3, total_hops, hop1_addr, hop2_addr, hop2_pte_addr,
251                 hop3_pte_addr, pte_val;
252         int rc, i, j, hop3_allocated = 0;
253
254         if (!hdev->mmu_enable)
255                 return 0;
256
257         mutex_init(&ctx->mmu_lock);
258         hash_init(ctx->mmu_hash);
259
260         if (!hdev->dram_supports_virtual_memory ||
261                         !hdev->dram_default_page_mapping)
262                 return 0;
263
264         num_of_hop3 = prop->dram_size_for_default_page_mapping;
265         do_div(num_of_hop3, prop->dram_page_size);
266         do_div(num_of_hop3, PTE_ENTRIES_IN_HOP);
267
268         /* add hop1 and hop2 */
269         total_hops = num_of_hop3 + 2;
270
271         ctx->dram_default_hops = kzalloc(HL_PTE_SIZE * total_hops,  GFP_KERNEL);
272         if (!ctx->dram_default_hops) {
273                 rc = -ENOMEM;
274                 goto alloc_err;
275         }
276
277         hop1_addr = alloc_hop(ctx);
278         if (hop1_addr == ULLONG_MAX) {
279                 dev_err(hdev->dev, "failed to alloc hop 1\n");
280                 rc = -ENOMEM;
281                 goto hop1_err;
282         }
283
284         ctx->dram_default_hops[total_hops - 1] = hop1_addr;
285
286         hop2_addr = alloc_hop(ctx);
287         if (hop2_addr == ULLONG_MAX) {
288                 dev_err(hdev->dev, "failed to alloc hop 2\n");
289                 rc = -ENOMEM;
290                 goto hop2_err;
291         }
292
293         ctx->dram_default_hops[total_hops - 2] = hop2_addr;
294
295         for (i = 0 ; i < num_of_hop3 ; i++) {
296                 ctx->dram_default_hops[i] = alloc_hop(ctx);
297                 if (ctx->dram_default_hops[i] == ULLONG_MAX) {
298                         dev_err(hdev->dev, "failed to alloc hop 3, i: %d\n", i);
299                         rc = -ENOMEM;
300                         goto hop3_err;
301                 }
302                 hop3_allocated++;
303         }
304
305         /* need only pte 0 in hops 0 and 1 */
306         pte_val = (hop1_addr & PTE_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
307         hdev->asic_funcs->write_pte(hdev, get_hop0_addr(ctx), pte_val);
308
309         pte_val = (hop2_addr & PTE_PHYS_ADDR_MASK) | PAGE_PRESENT_MASK;
310         hdev->asic_funcs->write_pte(hdev, hop1_addr, pte_val);
311         get_pte(ctx, hop1_addr);
312
313         hop2_pte_addr = hop2_addr;
314         for (i = 0 ; i < num_of_hop3 ; i++) {
315                 pte_val = (ctx->dram_default_hops[i] & PTE_PHYS_ADDR_MASK) |
316                                 PAGE_PRESENT_MASK;
317                 hdev->asic_funcs->write_pte(hdev, hop2_pte_addr, pte_val);
318                 get_pte(ctx, hop2_addr);
319                 hop2_pte_addr += HL_PTE_SIZE;
320         }
321
322         pte_val = (prop->mmu_dram_default_page_addr & PTE_PHYS_ADDR_MASK) |
323                         LAST_MASK | PAGE_PRESENT_MASK;
324
325         for (i = 0 ; i < num_of_hop3 ; i++) {
326                 hop3_pte_addr = ctx->dram_default_hops[i];
327                 for (j = 0 ; j < PTE_ENTRIES_IN_HOP ; j++) {
328                         hdev->asic_funcs->write_pte(hdev, hop3_pte_addr,
329                                         pte_val);
330                         get_pte(ctx, ctx->dram_default_hops[i]);
331                         hop3_pte_addr += HL_PTE_SIZE;
332                 }
333         }
334
335         /* flush all writes to reach PCI */
336         mb();
337         hdev->asic_funcs->read_pte(hdev, hop2_addr);
338
339         return 0;
340
341 hop3_err:
342         for (i = 0 ; i < hop3_allocated ; i++)
343                 free_hop(ctx, ctx->dram_default_hops[i]);
344         free_hop(ctx, hop2_addr);
345 hop2_err:
346         free_hop(ctx, hop1_addr);
347 hop1_err:
348         kfree(ctx->dram_default_hops);
349 alloc_err:
350         mutex_destroy(&ctx->mmu_lock);
351
352         return rc;
353 }
354
355 /*
356  * hl_mmu_ctx_fini - disable a ctx from using the mmu module
357  *
358  * @ctx: pointer to the context structure
359  *
360  * This function does the following:
361  * - Free any pgts which were not freed yet
362  * - Free the mutex
363  * - Free DRAM default page mapping hops
364  */
365 void hl_mmu_ctx_fini(struct hl_ctx *ctx)
366 {
367         struct hl_device *hdev = ctx->hdev;
368         struct asic_fixed_properties *prop = &hdev->asic_prop;
369         struct pgt_info *pgt_info;
370         struct hlist_node *tmp;
371         u64 num_of_hop3, total_hops, hop1_addr, hop2_addr, hop2_pte_addr,
372                 hop3_pte_addr;
373         int i, j;
374
375         if (!ctx->hdev->mmu_enable)
376                 return;
377
378         if (hdev->dram_supports_virtual_memory &&
379                         hdev->dram_default_page_mapping) {
380
381                 num_of_hop3 = prop->dram_size_for_default_page_mapping;
382                 do_div(num_of_hop3, prop->dram_page_size);
383                 do_div(num_of_hop3, PTE_ENTRIES_IN_HOP);
384
385                 /* add hop1 and hop2 */
386                 total_hops = num_of_hop3 + 2;
387                 hop1_addr = ctx->dram_default_hops[total_hops - 1];
388                 hop2_addr = ctx->dram_default_hops[total_hops - 2];
389
390                 for (i = 0 ; i < num_of_hop3 ; i++) {
391                         hop3_pte_addr = ctx->dram_default_hops[i];
392                         for (j = 0 ; j < PTE_ENTRIES_IN_HOP ; j++) {
393                                 clear_pte(hdev, hop3_pte_addr);
394                                 put_pte(ctx, ctx->dram_default_hops[i]);
395                                 hop3_pte_addr += HL_PTE_SIZE;
396                         }
397                 }
398
399                 hop2_pte_addr = hop2_addr;
400                 for (i = 0 ; i < num_of_hop3 ; i++) {
401                         clear_pte(hdev, hop2_pte_addr);
402                         put_pte(ctx, hop2_addr);
403                         hop2_pte_addr += HL_PTE_SIZE;
404                 }
405
406                 clear_pte(hdev, hop1_addr);
407                 put_pte(ctx, hop1_addr);
408                 clear_pte(hdev, get_hop0_addr(ctx));
409
410                 kfree(ctx->dram_default_hops);
411
412                 /* flush all writes to reach PCI */
413                 mb();
414                 hdev->asic_funcs->read_pte(hdev, hop2_addr);
415         }
416
417         if (!hash_empty(ctx->mmu_hash))
418                 dev_err(hdev->dev, "ctx is freed while it has pgts in use\n");
419
420         hash_for_each_safe(ctx->mmu_hash, i, tmp, pgt_info, node) {
421                 dev_err(hdev->dev,
422                         "pgt_info of addr 0x%llx of asid %d was not destroyed, num_ptes: %d\n",
423                         pgt_info->addr, ctx->asid, pgt_info->num_of_ptes);
424                 free_hop(ctx, pgt_info->addr);
425         }
426
427         mutex_destroy(&ctx->mmu_lock);
428 }
429
430 static int _hl_mmu_unmap(struct hl_ctx *ctx, u64 virt_addr)
431 {
432         struct hl_device *hdev = ctx->hdev;
433         struct asic_fixed_properties *prop = &hdev->asic_prop;
434         u64 hop0_addr = 0, hop0_pte_addr = 0,
435                 hop1_addr = 0, hop1_pte_addr = 0,
436                 hop2_addr = 0, hop2_pte_addr = 0,
437                 hop3_addr = 0, hop3_pte_addr = 0,
438                 hop4_addr = 0, hop4_pte_addr = 0,
439                 curr_pte;
440         int clear_hop3 = 1;
441         bool is_dram_addr, is_huge, is_dram_default_page_mapping;
442
443         is_dram_addr = hl_mem_area_inside_range(virt_addr, PAGE_SIZE_2MB,
444                                 prop->va_space_dram_start_address,
445                                 prop->va_space_dram_end_address);
446
447         hop0_addr = get_hop0_addr(ctx);
448
449         hop0_pte_addr = get_hop0_pte_addr(ctx, hop0_addr, virt_addr);
450
451         curr_pte = hdev->asic_funcs->read_pte(hdev, hop0_pte_addr);
452
453         hop1_addr = get_next_hop_addr(curr_pte);
454
455         if (hop1_addr == ULLONG_MAX)
456                 goto not_mapped;
457
458         hop1_pte_addr = get_hop1_pte_addr(ctx, hop1_addr, virt_addr);
459
460         curr_pte = hdev->asic_funcs->read_pte(hdev, hop1_pte_addr);
461
462         hop2_addr = get_next_hop_addr(curr_pte);
463
464         if (hop2_addr == ULLONG_MAX)
465                 goto not_mapped;
466
467         hop2_pte_addr = get_hop2_pte_addr(ctx, hop2_addr, virt_addr);
468
469         curr_pte = hdev->asic_funcs->read_pte(hdev, hop2_pte_addr);
470
471         hop3_addr = get_next_hop_addr(curr_pte);
472
473         if (hop3_addr == ULLONG_MAX)
474                 goto not_mapped;
475
476         hop3_pte_addr = get_hop3_pte_addr(ctx, hop3_addr, virt_addr);
477
478         curr_pte = hdev->asic_funcs->read_pte(hdev, hop3_pte_addr);
479
480         is_huge = curr_pte & LAST_MASK;
481
482         if (is_dram_addr && !is_huge) {
483                 dev_err(hdev->dev,
484                                 "DRAM unmapping should use huge pages only\n");
485                 return -EFAULT;
486         }
487
488         is_dram_default_page_mapping =
489                         hdev->dram_default_page_mapping && is_dram_addr;
490
491         if (!is_huge) {
492                 hop4_addr = get_next_hop_addr(curr_pte);
493
494                 if (hop4_addr == ULLONG_MAX)
495                         goto not_mapped;
496
497                 hop4_pte_addr = get_hop4_pte_addr(ctx, hop4_addr, virt_addr);
498
499                 curr_pte = hdev->asic_funcs->read_pte(hdev, hop4_pte_addr);
500
501                 clear_hop3 = 0;
502         }
503
504         if (is_dram_default_page_mapping) {
505                 u64 zero_pte = (prop->mmu_dram_default_page_addr &
506                                 PTE_PHYS_ADDR_MASK) | LAST_MASK |
507                                         PAGE_PRESENT_MASK;
508                 if (curr_pte == zero_pte) {
509                         dev_err(hdev->dev,
510                                 "DRAM: hop3 PTE points to zero page, can't unmap, va: 0x%llx\n",
511                                         virt_addr);
512                         goto not_mapped;
513                 }
514
515                 if (!(curr_pte & PAGE_PRESENT_MASK)) {
516                         dev_err(hdev->dev,
517                                 "DRAM: hop3 PTE is cleared! can't unmap, va: 0x%llx\n",
518                                         virt_addr);
519                         goto not_mapped;
520                 }
521
522                 hdev->asic_funcs->write_pte(hdev, hop3_pte_addr, zero_pte);
523                 put_pte(ctx, hop3_addr);
524         } else {
525                 if (!(curr_pte & PAGE_PRESENT_MASK))
526                         goto not_mapped;
527
528                 clear_pte(hdev, hop4_addr ? hop4_pte_addr : hop3_pte_addr);
529
530                 if (hop4_addr && !put_pte(ctx, hop4_addr))
531                         clear_hop3 = 1;
532
533                 if (!clear_hop3)
534                         goto flush;
535                 clear_pte(hdev, hop3_pte_addr);
536
537                 if (put_pte(ctx, hop3_addr))
538                         goto flush;
539                 clear_pte(hdev, hop2_pte_addr);
540
541                 if (put_pte(ctx, hop2_addr))
542                         goto flush;
543                 clear_pte(hdev, hop1_pte_addr);
544
545                 if (put_pte(ctx, hop1_addr))
546                         goto flush;
547                 clear_pte(hdev, hop0_pte_addr);
548         }
549
550 flush:
551         /* flush all writes from all cores to reach PCI */
552         mb();
553
554         hdev->asic_funcs->read_pte(hdev,
555                                 hop4_addr ? hop4_pte_addr : hop3_pte_addr);
556
557         return 0;
558
559 not_mapped:
560         dev_err(hdev->dev, "virt addr 0x%llx is not mapped to phys addr\n",
561                 virt_addr);
562
563         return -EINVAL;
564 }
565
566 /*
567  * hl_mmu_unmap - unmaps a virtual addr
568  *
569  * @ctx: pointer to the context structure
570  * @virt_addr: virt addr to map from
571  * @page_size: size of the page to unmap
572  *
573  * This function does the following:
574  * - Check that the virt addr is mapped
575  * - Unmap the virt addr and frees pgts if possible
576  * - Returns 0 on success, -EINVAL if the given addr is not mapped
577  *
578  * Because this function changes the page tables in the device and because it
579  * changes the MMU hash, it must be protected by a lock.
580  * However, because it maps only a single page, the lock should be implemented
581  * in a higher level in order to protect the entire mapping of the memory area
582  */
583 int hl_mmu_unmap(struct hl_ctx *ctx, u64 virt_addr, u32 page_size)
584 {
585         struct hl_device *hdev = ctx->hdev;
586         u64 real_virt_addr;
587         u32 real_page_size, npages;
588         int i, rc;
589
590         if (!hdev->mmu_enable)
591                 return 0;
592
593         /*
594          * The H/W handles mapping of 4KB/2MB page. Hence if the host page size
595          * is bigger, we break it to sub-pages and unmap them separately.
596          */
597         if ((page_size % PAGE_SIZE_2MB) == 0) {
598                 real_page_size = PAGE_SIZE_2MB;
599         } else if ((page_size % PAGE_SIZE_4KB) == 0) {
600                 real_page_size = PAGE_SIZE_4KB;
601         } else {
602                 dev_err(hdev->dev,
603                         "page size of %u is not 4KB nor 2MB aligned, can't unmap\n",
604                                 page_size);
605
606                 return -EFAULT;
607         }
608
609         npages = page_size / real_page_size;
610         real_virt_addr = virt_addr;
611
612         for (i = 0 ; i < npages ; i++) {
613                 rc = _hl_mmu_unmap(ctx, real_virt_addr);
614                 if (rc)
615                         return rc;
616
617                 real_virt_addr += real_page_size;
618         }
619
620         return 0;
621 }
622
623 static int _hl_mmu_map(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr,
624                 u32 page_size)
625 {
626         struct hl_device *hdev = ctx->hdev;
627         struct asic_fixed_properties *prop = &hdev->asic_prop;
628         u64 hop0_addr = 0, hop0_pte_addr = 0,
629                 hop1_addr = 0, hop1_pte_addr = 0,
630                 hop2_addr = 0, hop2_pte_addr = 0,
631                 hop3_addr = 0, hop3_pte_addr = 0,
632                 hop4_addr = 0, hop4_pte_addr = 0,
633                 curr_pte = 0;
634         bool hop1_new = false, hop2_new = false, hop3_new = false,
635                 hop4_new = false, is_huge, is_dram_addr,
636                 is_dram_default_page_mapping;
637         int rc = -ENOMEM;
638
639         /*
640          * This mapping function can map a 4KB/2MB page. For 2MB page there are
641          * only 3 hops rather than 4. Currently the DRAM allocation uses 2MB
642          * pages only but user memory could have been allocated with one of the
643          * two page sizes. Since this is a common code for all the three cases,
644          * we need this hugs page check.
645          */
646         is_huge = page_size == PAGE_SIZE_2MB;
647
648         is_dram_addr = hl_mem_area_inside_range(virt_addr, page_size,
649                                 prop->va_space_dram_start_address,
650                                 prop->va_space_dram_end_address);
651
652         if (is_dram_addr && !is_huge) {
653                 dev_err(hdev->dev, "DRAM mapping should use huge pages only\n");
654                 return -EFAULT;
655         }
656
657         is_dram_default_page_mapping =
658                         hdev->dram_default_page_mapping && is_dram_addr;
659
660         hop0_addr = get_hop0_addr(ctx);
661
662         hop0_pte_addr = get_hop0_pte_addr(ctx, hop0_addr, virt_addr);
663
664         curr_pte = hdev->asic_funcs->read_pte(hdev, hop0_pte_addr);
665
666         hop1_addr = get_alloc_next_hop_addr(ctx, curr_pte, &hop1_new);
667
668         if (hop1_addr == ULLONG_MAX)
669                 goto err;
670
671         hop1_pte_addr = get_hop1_pte_addr(ctx, hop1_addr, virt_addr);
672
673         curr_pte = hdev->asic_funcs->read_pte(hdev, hop1_pte_addr);
674
675         hop2_addr = get_alloc_next_hop_addr(ctx, curr_pte, &hop2_new);
676
677         if (hop2_addr == ULLONG_MAX)
678                 goto err;
679
680         hop2_pte_addr = get_hop2_pte_addr(ctx, hop2_addr, virt_addr);
681
682         curr_pte = hdev->asic_funcs->read_pte(hdev, hop2_pte_addr);
683
684         hop3_addr = get_alloc_next_hop_addr(ctx, curr_pte, &hop3_new);
685
686         if (hop3_addr == ULLONG_MAX)
687                 goto err;
688
689         hop3_pte_addr = get_hop3_pte_addr(ctx, hop3_addr, virt_addr);
690
691         curr_pte = hdev->asic_funcs->read_pte(hdev, hop3_pte_addr);
692
693         if (!is_huge) {
694                 hop4_addr = get_alloc_next_hop_addr(ctx, curr_pte, &hop4_new);
695
696                 if (hop4_addr == ULLONG_MAX)
697                         goto err;
698
699                 hop4_pte_addr = get_hop4_pte_addr(ctx, hop4_addr, virt_addr);
700
701                 curr_pte = hdev->asic_funcs->read_pte(hdev, hop4_pte_addr);
702         }
703
704         if (is_dram_default_page_mapping) {
705                 u64 zero_pte = (prop->mmu_dram_default_page_addr &
706                                         PTE_PHYS_ADDR_MASK) | LAST_MASK |
707                                                 PAGE_PRESENT_MASK;
708
709                 if (curr_pte != zero_pte) {
710                         dev_err(hdev->dev,
711                                 "DRAM: mapping already exists for virt_addr 0x%llx\n",
712                                         virt_addr);
713                         rc = -EINVAL;
714                         goto err;
715                 }
716
717                 if (hop1_new || hop2_new || hop3_new || hop4_new) {
718                         dev_err(hdev->dev,
719                                 "DRAM mapping should not allocate more hops\n");
720                         rc = -EFAULT;
721                         goto err;
722                 }
723         } else if (curr_pte & PAGE_PRESENT_MASK) {
724                 dev_err(hdev->dev,
725                                 "mapping already exists for virt_addr 0x%llx\n",
726                                         virt_addr);
727
728                 dev_dbg(hdev->dev, "hop0 pte: 0x%llx (0x%llx)\n",
729                                 hdev->asic_funcs->read_pte(hdev, hop0_pte_addr),
730                                 hop0_pte_addr);
731                 dev_dbg(hdev->dev, "hop1 pte: 0x%llx (0x%llx)\n",
732                                 hdev->asic_funcs->read_pte(hdev, hop1_pte_addr),
733                                 hop1_pte_addr);
734                 dev_dbg(hdev->dev, "hop2 pte: 0x%llx (0x%llx)\n",
735                                 hdev->asic_funcs->read_pte(hdev, hop2_pte_addr),
736                                 hop2_pte_addr);
737                 dev_dbg(hdev->dev, "hop3 pte: 0x%llx (0x%llx)\n",
738                                 hdev->asic_funcs->read_pte(hdev, hop3_pte_addr),
739                                 hop3_pte_addr);
740
741                 if (!is_huge)
742                         dev_dbg(hdev->dev, "hop4 pte: 0x%llx (0x%llx)\n",
743                                 hdev->asic_funcs->read_pte(hdev,
744                                                         hop4_pte_addr),
745                                                         hop4_pte_addr);
746
747                 rc = -EINVAL;
748                 goto err;
749         }
750
751         curr_pte = (phys_addr & PTE_PHYS_ADDR_MASK) | LAST_MASK
752                         | PAGE_PRESENT_MASK;
753
754         hdev->asic_funcs->write_pte(hdev,
755                                 is_huge ? hop3_pte_addr : hop4_pte_addr,
756                                 curr_pte);
757
758         if (hop1_new) {
759                 curr_pte = (hop1_addr & PTE_PHYS_ADDR_MASK) |
760                                 PAGE_PRESENT_MASK;
761                 ctx->hdev->asic_funcs->write_pte(ctx->hdev, hop0_pte_addr,
762                                 curr_pte);
763         }
764         if (hop2_new) {
765                 curr_pte = (hop2_addr & PTE_PHYS_ADDR_MASK) |
766                                 PAGE_PRESENT_MASK;
767                 ctx->hdev->asic_funcs->write_pte(ctx->hdev, hop1_pte_addr,
768                                 curr_pte);
769                 get_pte(ctx, hop1_addr);
770         }
771         if (hop3_new) {
772                 curr_pte = (hop3_addr & PTE_PHYS_ADDR_MASK) |
773                                 PAGE_PRESENT_MASK;
774                 ctx->hdev->asic_funcs->write_pte(ctx->hdev, hop2_pte_addr,
775                                 curr_pte);
776                 get_pte(ctx, hop2_addr);
777         }
778
779         if (!is_huge) {
780                 if (hop4_new) {
781                         curr_pte = (hop4_addr & PTE_PHYS_ADDR_MASK) |
782                                         PAGE_PRESENT_MASK;
783                         ctx->hdev->asic_funcs->write_pte(ctx->hdev,
784                                         hop3_pte_addr, curr_pte);
785                         get_pte(ctx, hop3_addr);
786                 }
787
788                 get_pte(ctx, hop4_addr);
789         } else {
790                 get_pte(ctx, hop3_addr);
791         }
792
793         /* flush all writes from all cores to reach PCI */
794         mb();
795
796         hdev->asic_funcs->read_pte(hdev,
797                                 is_huge ? hop3_pte_addr : hop4_pte_addr);
798
799         return 0;
800
801 err:
802         if (hop4_new)
803                 free_hop(ctx, hop4_addr);
804         if (hop3_new)
805                 free_hop(ctx, hop3_addr);
806         if (hop2_new)
807                 free_hop(ctx, hop2_addr);
808         if (hop1_new)
809                 free_hop(ctx, hop1_addr);
810
811         return rc;
812 }
813
814 /*
815  * hl_mmu_map - maps a virtual addr to physical addr
816  *
817  * @ctx: pointer to the context structure
818  * @virt_addr: virt addr to map from
819  * @phys_addr: phys addr to map to
820  * @page_size: physical page size
821  *
822  * This function does the following:
823  * - Check that the virt addr is not mapped
824  * - Allocate pgts as necessary in order to map the virt addr to the phys
825  * - Returns 0 on success, -EINVAL if addr is already mapped, or -ENOMEM.
826  *
827  * Because this function changes the page tables in the device and because it
828  * changes the MMU hash, it must be protected by a lock.
829  * However, because it maps only a single page, the lock should be implemented
830  * in a higher level in order to protect the entire mapping of the memory area
831  */
832 int hl_mmu_map(struct hl_ctx *ctx, u64 virt_addr, u64 phys_addr, u32 page_size)
833 {
834         struct hl_device *hdev = ctx->hdev;
835         u64 real_virt_addr;
836         u32 real_page_size, npages;
837         int i, rc, mapped_cnt = 0;
838
839         if (!hdev->mmu_enable)
840                 return 0;
841
842         /*
843          * The H/W handles mapping of 4KB/2MB page. Hence if the host page size
844          * is bigger, we break it to sub-pages and map them separately.
845          */
846         if ((page_size % PAGE_SIZE_2MB) == 0) {
847                 real_page_size = PAGE_SIZE_2MB;
848         } else if ((page_size % PAGE_SIZE_4KB) == 0) {
849                 real_page_size = PAGE_SIZE_4KB;
850         } else {
851                 dev_err(hdev->dev,
852                         "page size of %u is not 4KB nor 2MB aligned, can't map\n",
853                                 page_size);
854
855                 return -EFAULT;
856         }
857
858         npages = page_size / real_page_size;
859         real_virt_addr = virt_addr;
860
861         for (i = 0 ; i < npages ; i++) {
862                 rc = _hl_mmu_map(ctx, real_virt_addr, phys_addr,
863                                 real_page_size);
864                 if (rc)
865                         goto err;
866
867                 real_virt_addr += real_page_size;
868                 mapped_cnt++;
869         }
870
871         return 0;
872
873 err:
874         real_virt_addr = virt_addr;
875         for (i = 0 ; i < mapped_cnt ; i++) {
876                 if (_hl_mmu_unmap(ctx, real_virt_addr))
877                         dev_warn_ratelimited(hdev->dev,
878                                 "failed to unmap va: 0x%llx\n", real_virt_addr);
879
880                 real_virt_addr += real_page_size;
881         }
882
883         return rc;
884 }
885
886 /*
887  * hl_mmu_swap_out - marks all mapping of the given ctx as swapped out
888  *
889  * @ctx: pointer to the context structure
890  *
891  */
892 void hl_mmu_swap_out(struct hl_ctx *ctx)
893 {
894
895 }
896
897 /*
898  * hl_mmu_swap_in - marks all mapping of the given ctx as swapped in
899  *
900  * @ctx: pointer to the context structure
901  *
902  */
903 void hl_mmu_swap_in(struct hl_ctx *ctx)
904 {
905
906 }