diff --git a/arch/lkl/kernel/setup.c b/arch/lkl/kernel/setup.c
index 3bce11e0b97bf2..51b09fe6fde21e 100644
--- a/arch/lkl/kernel/setup.c
+++ b/arch/lkl/kernel/setup.c
@@ -147,7 +147,10 @@ void arch_cpu_idle(void)
 {
         if (halt) {
                 threads_cleanup();
-                free_mem();
+                /* TODO(pscollins): If we free here, it causes a
+                 * segfault because the tx/rx threads are still
+                 * running in parallel. */
+                /* free_mem(); */
                 lkl_ops->sem_up(halt_sem);
                 lkl_ops->thread_exit();
         }
diff --git a/tools/lkl/lib/virtio.c b/tools/lkl/lib/virtio.c
index f88bace883ca30..2332487971b593 100644
--- a/tools/lkl/lib/virtio.c
+++ b/tools/lkl/lib/virtio.c
@@ -131,6 +131,23 @@ static int virtio_process_one(struct virtio_dev *dev, struct virtio_queue *q,
         return 0;
 }
 
+/* NB: we can enter this function two different ways in the case of
+ * netdevs --- either through a tx/rx thread poll (which the LKL
+ * scheduler knows nothing about) or through virtio_write called
+ * inside an interrupt handler, so to be safe, it's not enough to
+ * synchronize only the tx/rx polling threads.
+ *
+ * At the moment, it seems like only netdevs require the
+ * synchronization we do here (i.e. locking around operations on a
+ * particular virtqueue, with dev->ops->acquire_queue), since they
+ * have these two different entry points, one of which isn't managed
+ * by the LKL scheduler. So only devs corresponding to netdevs will
+ * have non-NULL acquire/release_queue.
+ *
+ * In the future, this may change. If you see errors thrown in virtio
+ * driver code by block/console devices, you should be suspicious of
+ * the synchronization going on here.
+ */
 void virtio_process_queue(struct virtio_dev *dev, uint32_t qidx)
 {
         struct virtio_queue *q = &dev->queue[qidx];
@@ -138,12 +155,18 @@ void virtio_process_queue(struct virtio_dev *dev, uint32_t qidx)
         if (!q->ready)
                 return;
 
+        if (dev->ops->acquire_queue)
+                dev->ops->acquire_queue(dev, qidx);
+
         virtio_set_avail_event(q, q->avail->idx);
 
         while (q->last_avail_idx != le16toh(q->avail->idx)) {
                 if (virtio_process_one(dev, q, q->last_avail_idx) < 0)
-                        return;
+                        break;
         }
+
+        if (dev->ops->release_queue)
+                dev->ops->release_queue(dev, qidx);
 }
 
 static inline uint32_t virtio_read_device_features(struct virtio_dev *dev)
diff --git a/tools/lkl/lib/virtio.h b/tools/lkl/lib/virtio.h
index 6873811b942df9..9154ec5f95262a 100644
--- a/tools/lkl/lib/virtio.h
+++ b/tools/lkl/lib/virtio.h
@@ -23,6 +23,11 @@ struct virtio_dev_ops {
          * virtio_process_queue at a later time.
          */
         int (*enqueue)(struct virtio_dev *dev, struct virtio_req *req);
+        /* Acquire/release a lock on the specified queue. Only
+         * implemented by netdevs, all other devices have NULL
+         * acquire/release function pointers. */
+        void (*acquire_queue)(struct virtio_dev *dev, int queue_idx);
+        void (*release_queue)(struct virtio_dev *dev, int queue_idx);
 };
 
 struct virtio_queue {
diff --git a/tools/lkl/lib/virtio_net.c b/tools/lkl/lib/virtio_net.c
index ee5b3b9ac7acc3..46cb7102f3a0f8 100644
--- a/tools/lkl/lib/virtio_net.c
+++ b/tools/lkl/lib/virtio_net.c
@@ -5,6 +5,7 @@
 
 #include
 
+#define netdev_of(x) (container_of(x, struct virtio_net_dev, dev))
 #define BIT(x) (1ULL << x)
 
 /* We always have 2 queues on a netdev: one for tx, one for rx. */
@@ -24,7 +25,6 @@ struct virtio_net_poll {
         struct virtio_net_dev *dev;
 
-        struct lkl_sem_t *sem;
         int event;
 };
 
@@ -34,6 +34,7 @@ struct virtio_net_dev {
         struct lkl_dev_net_ops *ops;
         union lkl_netdev nd;
         struct virtio_net_poll rx_poll, tx_poll;
+        struct lkl_mutex_t **queue_locks;
 };
 
 static int net_check_features(struct virtio_dev *dev)
@@ -44,6 +45,16 @@ static int net_check_features(struct virtio_dev *dev)
         return -LKL_EINVAL;
 }
 
+static void net_acquire_queue(struct virtio_dev *dev, int queue_idx)
+{
+        lkl_host_ops.mutex_lock(netdev_of(dev)->queue_locks[queue_idx]);
+}
+
+static void net_release_queue(struct virtio_dev *dev, int queue_idx)
+{
+        lkl_host_ops.mutex_unlock(netdev_of(dev)->queue_locks[queue_idx]);
+}
+
 static inline int is_rx_queue(struct virtio_dev *dev, struct virtio_queue *queue)
 {
         return &dev->queue[RX_QUEUE_IDX] == queue;
 }
@@ -62,7 +73,7 @@ static int net_enqueue(struct virtio_dev *dev, struct virtio_req *req)
         void *buf;
 
         header = req->buf[0].addr;
-        net_dev = container_of(dev, struct virtio_net_dev, dev);
+        net_dev = netdev_of(dev);
         len = req->buf[0].len - sizeof(*header);
         buf = &header[1];
 
@@ -75,17 +86,13 @@ static int net_enqueue(struct virtio_dev *dev, struct virtio_req *req)
         /* Pick which virtqueue to send the buffer(s) to */
         if (is_tx_queue(dev, req->q)) {
                 ret = net_dev->ops->tx(net_dev->nd, buf, len);
-                if (ret < 0) {
-                        lkl_host_ops.sem_up(net_dev->tx_poll.sem);
+                if (ret < 0)
                         return -1;
-                }
         } else if (is_rx_queue(dev, req->q)) {
                 header->num_buffers = 1;
                 ret = net_dev->ops->rx(net_dev->nd, buf, &len);
-                if (ret < 0) {
-                        lkl_host_ops.sem_up(net_dev->rx_poll.sem);
+                if (ret < 0)
                         return -1;
-                }
         } else {
                 bad_request("tried to push on non-existent queue");
                 return -1;
@@ -98,6 +105,8 @@ static struct virtio_dev_ops net_ops = {
         .check_features = net_check_features,
         .enqueue = net_enqueue,
+        .acquire_queue = net_acquire_queue,
+        .release_queue = net_release_queue,
 };
 
 void poll_thread(void *arg)
@@ -105,15 +114,47 @@ void poll_thread(void *arg)
         struct virtio_net_poll *np = (struct virtio_net_poll *)arg;
         int ret;
 
+        /* Synchronization is handled in virtio_process_queue */
         while ((ret = np->dev->ops->poll(np->dev->nd, np->event)) >= 0) {
                 if (ret & LKL_DEV_NET_POLL_RX)
                         virtio_process_queue(&np->dev->dev, 0);
                 if (ret & LKL_DEV_NET_POLL_TX)
                         virtio_process_queue(&np->dev->dev, 1);
-                lkl_host_ops.sem_down(np->sem);
         }
 }
 
+static void free_queue_locks(struct lkl_mutex_t **queues, int num_queues)
+{
+        int i = 0;
+        if (!queues)
+                return;
+
+        for (i = 0; i < num_queues; i++)
+                lkl_host_ops.mem_free(queues[i]);
+
+        lkl_host_ops.mem_free(queues);
+}
+
+static struct lkl_mutex_t **init_queue_locks(int num_queues)
+{
+        int i;
+        struct lkl_mutex_t **ret = lkl_host_ops.mem_alloc(
+                sizeof(struct lkl_mutex_t*) * num_queues);
+        if (!ret)
+                return NULL;
+
+        memset(ret, 0, sizeof(struct lkl_mutex_t*) * num_queues);
+        for (i = 0; i < num_queues; i++) {
+                ret[i] = lkl_host_ops.mutex_alloc();
+                if (!ret[i]) {
+                        free_queue_locks(ret, num_queues);
+                        return NULL;
+                }
+        }
+
+        return ret;
+}
+
 int lkl_netdev_add(union lkl_netdev nd, void *mac)
 {
         struct virtio_net_dev *dev;
@@ -134,22 +175,26 @@ int lkl_netdev_add(union lkl_netdev nd, void *mac)
         dev->dev.ops = &net_ops;
         dev->ops = &lkl_dev_net_ops;
         dev->nd = nd;
+        dev->queue_locks = init_queue_locks(NUM_QUEUES);
+
+        if (!dev->queue_locks)
+                goto out_free;
 
         if (mac)
                 memcpy(dev->config.mac, mac, LKL_ETH_ALEN);
 
         dev->rx_poll.event = LKL_DEV_NET_POLL_RX;
-        dev->rx_poll.sem = lkl_host_ops.sem_alloc(0);
         dev->rx_poll.dev = dev;
 
         dev->tx_poll.event = LKL_DEV_NET_POLL_TX;
-        dev->tx_poll.sem = lkl_host_ops.sem_alloc(0);
         dev->tx_poll.dev = dev;
 
-        if (!dev->rx_poll.sem || !dev->tx_poll.sem)
-                goto out_free;
+        /* MUST match the number of queue locks we initialized. We
+         * could init the queues in virtio_dev_setup to help enforce
+         * this, but netdevs are the only flavor that need these
+         * locks, so it's better to do it here. */
+        ret = virtio_dev_setup(&dev->dev, NUM_QUEUES, 32);
 
-        ret = virtio_dev_setup(&dev->dev, 2, 32);
         if (ret)
                 goto out_free;
@@ -167,10 +212,8 @@ int lkl_netdev_add(union lkl_netdev nd, void *mac)
         virtio_dev_cleanup(&dev->dev);
 
 out_free:
-        if (dev->rx_poll.sem)
-                lkl_host_ops.sem_free(dev->rx_poll.sem);
-        if (dev->tx_poll.sem)
-                lkl_host_ops.sem_free(dev->tx_poll.sem);
+        if (dev->queue_locks)
+                free_queue_locks(dev->queue_locks, NUM_QUEUES);
         lkl_host_ops.mem_free(dev);
         return ret;
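
For context, the patch assumes the host operations table already exposes mutex primitives (lkl_host_ops.mutex_alloc, mutex_lock, mutex_unlock) alongside the existing sem_*/mem_* hooks. Below is a minimal sketch of how a POSIX host backend might implement those three hooks; the struct lkl_mutex_t layout and the posix_* names are illustrative assumptions, not the actual LKL host library code.

/* Illustrative sketch only: a pthread-backed implementation of the
 * mutex hooks this patch calls through lkl_host_ops. The real host
 * library defines struct lkl_mutex_t and wires these into its own
 * ops table; names and layout here are assumptions. */
#include <pthread.h>
#include <stdlib.h>

struct lkl_mutex_t {
        pthread_mutex_t mutex;
};

static struct lkl_mutex_t *posix_mutex_alloc(void)
{
        struct lkl_mutex_t *m = malloc(sizeof(*m));

        if (!m)
                return NULL;
        pthread_mutex_init(&m->mutex, NULL);
        return m;
}

static void posix_mutex_lock(struct lkl_mutex_t *m)
{
        pthread_mutex_lock(&m->mutex);
}

static void posix_mutex_unlock(struct lkl_mutex_t *m)
{
        pthread_mutex_unlock(&m->mutex);
}

Note that free_queue_locks in the patch releases the locks with lkl_host_ops.mem_free rather than a dedicated mutex_free hook, so under a pthread-backed host like the sketch above pthread_mutex_destroy would never run; whether that matters depends on the host implementation.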