diff options
Diffstat (limited to 'drivers/edgetpu/edgetpu-device-group.c')
-rw-r--r-- | drivers/edgetpu/edgetpu-device-group.c | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/drivers/edgetpu/edgetpu-device-group.c b/drivers/edgetpu/edgetpu-device-group.c index 49c672d..86ef9d2 100644 --- a/drivers/edgetpu/edgetpu-device-group.c +++ b/drivers/edgetpu/edgetpu-device-group.c @@ -21,6 +21,7 @@ #include <linux/seq_file.h> #include <linux/slab.h> #include <linux/spinlock.h> +#include <linux/uaccess.h> #include <linux/uidgid.h> #include "edgetpu-async.h" @@ -1168,10 +1169,16 @@ static struct page **edgetpu_pin_user_pages(struct edgetpu_device_group *group, if (size == 0) return ERR_PTR(-EINVAL); + if (!access_ok((const void *)host_addr, size)) { + etdev_err(etdev, "invalid address range in buffer map request"); + return ERR_PTR(-EFAULT); + } offset = host_addr & (PAGE_SIZE - 1); - /* overflow check */ - if (unlikely((size + offset) / PAGE_SIZE >= UINT_MAX - 1 || size + offset < size)) - return ERR_PTR(-ENOMEM); + /* overflow check (should also be caught by access_ok) */ + if (unlikely((size + offset) / PAGE_SIZE >= UINT_MAX - 1 || size + offset < size)) { + etdev_err(etdev, "address overflow in buffer map request"); + return ERR_PTR(-EFAULT); + } num_pages = DIV_ROUND_UP((size + offset), PAGE_SIZE); etdev_dbg(etdev, "%s: hostaddr=%#llx pages=%u", __func__, host_addr, num_pages); /* @@ -1204,10 +1211,20 @@ static struct page **edgetpu_pin_user_pages(struct edgetpu_device_group *group, *pnum_pages = num_pages; return pages; } + if (ret == -EFAULT && !*preadonly) { + foll_flags &= ~FOLL_WRITE; + *preadonly = true; + ret = pin_user_pages_fast(host_addr & PAGE_MASK, num_pages, + foll_flags, pages); + } if (ret < 0) { etdev_dbg(etdev, "pin_user_pages failed %u:%pK-%u: %d", group->workload_id, (void *)host_addr, num_pages, ret); + if (ret == -EFAULT) + etdev_err(etdev, + "bad address locking %u pages for %s", + num_pages, *preadonly ? "read" : "write"); if (ret != -ENOMEM) { num_pages = 0; goto error; @@ -1236,12 +1253,11 @@ static struct page **edgetpu_pin_user_pages(struct edgetpu_device_group *group, etdev_dbg(etdev, "pin_user_pages failed %u:%pK-%u: %d", group->workload_id, (void *)host_addr, num_pages, ret); - num_pages = 0; - if (ret == -ENOMEM) etdev_err(etdev, "system out of memory locking %u pages", num_pages); + num_pages = 0; goto error; } if (ret < num_pages) { |