lkl: Fix race condition in virtio_net rx path
As it stands now, the only synchronization on the virtio net
devices happens in the virtio_net_poll struct (i.e. in the tx/rx polling
threads). This is insufficient --- virtio_write can be called in an
interrupt handler and call virtio_process_queue on a queue attached to a
netdev, which bypasses the synchronization in the virtio_net_poll struct
and eventually causes virtqueue_get_buf (in
drivers/virtio/virtio_ring.c) to fail, which hangs LKL.

This commit adds a mutex to protect each queue of the virtio_net_dev,
and pushes the synchronization into virtio_process_queue, which solves
the race and has more or less the same performance as the original
implementation.

Signed-off-by: Patrick Collins <pscollins@google.com>
pscollins committed Feb 23, 2016
1 parent 8f7a7c2 commit 1f5a1af
Showing 4 changed files with 94 additions and 20 deletions.
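
The shape of the fix is easier to see outside the diff: every path that touches a virtqueue, whether the rx/tx poll thread or virtio_write running in an interrupt handler, has to take the same per-queue lock before walking the ring. The standalone toy program below is only an illustration (none of these names or types are LKL code; a plain pthread mutex stands in for the acquire_queue/release_queue hooks this commit adds):

#include <pthread.h>
#include <stdint.h>
#include <stdio.h>

struct toy_queue {
	pthread_mutex_t lock;		/* plays the role of queue_locks[qidx] */
	uint16_t last_avail_idx;	/* consumer progress */
	uint16_t avail_idx;		/* producer progress */
};

/* Every path into the queue funnels through here, as virtio_process_queue does. */
static void process_queue(struct toy_queue *q)
{
	pthread_mutex_lock(&q->lock);		/* acquire_queue() in the patch */
	while (q->last_avail_idx != q->avail_idx)
		q->last_avail_idx++;		/* "consume" one pending buffer */
	pthread_mutex_unlock(&q->lock);		/* release_queue() in the patch */
}

/* Entry point 1: an rx/tx poll thread the LKL scheduler knows nothing about. */
static void *toy_poll_thread(void *arg)
{
	struct toy_queue *q = arg;

	for (int i = 0; i < 100000; i++)
		process_queue(q);
	return NULL;
}

int main(void)
{
	struct toy_queue q = { .lock = PTHREAD_MUTEX_INITIALIZER };
	pthread_t t;

	pthread_create(&t, NULL, toy_poll_thread, &q);

	/* Entry point 2: stands in for virtio_write called from an interrupt
	 * handler; it reaches the same queue through the same lock. */
	for (int i = 0; i < 100000; i++) {
		pthread_mutex_lock(&q.lock);
		q.avail_idx++;			/* a new buffer was posted */
		pthread_mutex_unlock(&q.lock);
		process_queue(&q);
	}

	pthread_join(t, NULL);
	printf("consumed up to index %u\n", (unsigned)q.last_avail_idx);
	return 0;
}

Build with cc -pthread. The point is purely structural: one lock per queue, taken on both entry paths, which is what queue_locks plus net_acquire_queue/net_release_queue provide in the patch below.
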
5 changes: 4 additions & 1 deletion arch/lkl/kernel/setup.c
@@ -147,7 +147,10 @@ void arch_cpu_idle(void)
{
if (halt) {
threads_cleanup();
-free_mem();
+/* TODO(pscollins): If we free here, it causes a
+ * segfault because the tx/rx threads are still
+ * running in parallel. */
+/* free_mem(); */
lkl_ops->sem_up(halt_sem);
lkl_ops->thread_exit();
}
25 changes: 24 additions & 1 deletion tools/lkl/lib/virtio.c
@@ -131,19 +131,42 @@ static int virtio_process_one(struct virtio_dev *dev, struct virtio_queue *q,
return 0;
}

/* NB: we can enter this function two different ways in the case of
* netdevs --- either through a tx/rx thread poll (which the LKL
* scheduler knows nothing about) or through virtio_write called
* inside an interrupt handler, so to be safe, it's not enough to
* synchronize only the tx/rx polling threads.
*
* At the moment, it seems like only netdevs require the
* synchronization we do here (i.e. locking around operations on a
* particular virtqueue, with dev->ops->acquire_queue), since they
* have these two different entry points, one of which isn't managed
* by the LKL scheduler. So only devs corresponding to netdevs will
* have non-NULL acquire/release_queue.
*
* In the future, this may change. If you see errors thrown in virtio
* driver code by block/console devices, you should be suspicious of
* the synchronization going on here.
*/
void virtio_process_queue(struct virtio_dev *dev, uint32_t qidx)
{
struct virtio_queue *q = &dev->queue[qidx];

if (!q->ready)
return;

if (dev->ops->acquire_queue)
dev->ops->acquire_queue(dev, qidx);

virtio_set_avail_event(q, q->avail->idx);

while (q->last_avail_idx != le16toh(q->avail->idx)) {
if (virtio_process_one(dev, q, q->last_avail_idx) < 0)
-return;
+break;
}

if (dev->ops->release_queue)
dev->ops->release_queue(dev, qidx);
}

static inline uint32_t virtio_read_device_features(struct virtio_dev *dev)
5 changes: 5 additions & 0 deletions tools/lkl/lib/virtio.h
@@ -23,6 +23,11 @@ struct virtio_dev_ops {
* virtio_process_queue at a later time.
*/
int (*enqueue)(struct virtio_dev *dev, struct virtio_req *req);
/* Acquire/release a lock on the specified queue. Only
* implemented by netdevs; all other devices have NULL
* acquire/release function pointers. */
void (*acquire_queue)(struct virtio_dev *dev, int queue_idx);
void (*release_queue)(struct virtio_dev *dev, int queue_idx);
};

struct virtio_queue {
79 changes: 61 additions & 18 deletions tools/lkl/lib/virtio_net.c
@@ -5,6 +5,7 @@

#include <lkl/linux/virtio_net.h>

#define netdev_of(x) (container_of(x, struct virtio_net_dev, dev))
#define BIT(x) (1ULL << x)

/* We always have 2 queues on a netdev: one for tx, one for rx. */
@@ -24,7 +25,6 @@

struct virtio_net_poll {
struct virtio_net_dev *dev;
-struct lkl_sem_t *sem;
int event;
};

@@ -34,6 +34,7 @@ struct virtio_net_dev {
struct lkl_dev_net_ops *ops;
union lkl_netdev nd;
struct virtio_net_poll rx_poll, tx_poll;
struct lkl_mutex_t **queue_locks;
};

static int net_check_features(struct virtio_dev *dev)
@@ -44,6 +45,16 @@ static int net_check_features(struct virtio_dev *dev)
return -LKL_EINVAL;
}

static void net_acquire_queue(struct virtio_dev *dev, int queue_idx)
{
lkl_host_ops.mutex_lock(netdev_of(dev)->queue_locks[queue_idx]);
}

static void net_release_queue(struct virtio_dev *dev, int queue_idx)
{
lkl_host_ops.mutex_unlock(netdev_of(dev)->queue_locks[queue_idx]);
}

static inline int is_rx_queue(struct virtio_dev *dev, struct virtio_queue *queue)
{
return &dev->queue[RX_QUEUE_IDX] == queue;
@@ -62,7 +73,7 @@ static int net_enqueue(struct virtio_dev *dev, struct virtio_req *req)
void *buf;

header = req->buf[0].addr;
-net_dev = container_of(dev, struct virtio_net_dev, dev);
+net_dev = netdev_of(dev);
len = req->buf[0].len - sizeof(*header);

buf = &header[1];
@@ -75,17 +86,13 @@ static int net_enqueue(struct virtio_dev *dev, struct virtio_req *req)
/* Pick which virtqueue to send the buffer(s) to */
if (is_tx_queue(dev, req->q)) {
ret = net_dev->ops->tx(net_dev->nd, buf, len);
-if (ret < 0) {
-lkl_host_ops.sem_up(net_dev->tx_poll.sem);
+if (ret < 0)
return -1;
-}
} else if (is_rx_queue(dev, req->q)) {
header->num_buffers = 1;
ret = net_dev->ops->rx(net_dev->nd, buf, &len);
-if (ret < 0) {
-lkl_host_ops.sem_up(net_dev->rx_poll.sem);
+if (ret < 0)
return -1;
-}
} else {
bad_request("tried to push on non-existent queue");
return -1;
@@ -98,22 +105,56 @@ static int net_enqueue(struct virtio_dev *dev, struct virtio_req *req)
static struct virtio_dev_ops net_ops = {
.check_features = net_check_features,
.enqueue = net_enqueue,
.acquire_queue = net_acquire_queue,
.release_queue = net_release_queue,
};

void poll_thread(void *arg)
{
struct virtio_net_poll *np = (struct virtio_net_poll *)arg;
int ret;

/* Synchronization is handled in virtio_process_queue */
while ((ret = np->dev->ops->poll(np->dev->nd, np->event)) >= 0) {
if (ret & LKL_DEV_NET_POLL_RX)
virtio_process_queue(&np->dev->dev, 0);
if (ret & LKL_DEV_NET_POLL_TX)
virtio_process_queue(&np->dev->dev, 1);
-lkl_host_ops.sem_down(np->sem);
}
}

static void free_queue_locks(struct lkl_mutex_t **queues, int num_queues)
{
int i = 0;
if (!queues)
return;

for (i = 0; i < num_queues; i++)
lkl_host_ops.mem_free(queues[i]);

lkl_host_ops.mem_free(queues);
}

static struct lkl_mutex_t **init_queue_locks(int num_queues)
{
int i;
struct lkl_mutex_t **ret = lkl_host_ops.mem_alloc(
sizeof(struct lkl_mutex_t*) * num_queues);
if (!ret)
return NULL;

memset(ret, 0, sizeof(struct lkl_mutex_t*) * num_queues);
for (i = 0; i < num_queues; i++) {
ret[i] = lkl_host_ops.mutex_alloc();
if (!ret[i]) {
free_queue_locks(ret, num_queues);
return NULL;
}
}

return ret;
}

int lkl_netdev_add(union lkl_netdev nd, void *mac)
{
struct virtio_net_dev *dev;
@@ -134,22 +175,26 @@ int lkl_netdev_add(union lkl_netdev nd, void *mac)
dev->dev.ops = &net_ops;
dev->ops = &lkl_dev_net_ops;
dev->nd = nd;
dev->queue_locks = init_queue_locks(NUM_QUEUES);

if (!dev->queue_locks)
goto out_free;

if (mac)
memcpy(dev->config.mac, mac, LKL_ETH_ALEN);

dev->rx_poll.event = LKL_DEV_NET_POLL_RX;
-dev->rx_poll.sem = lkl_host_ops.sem_alloc(0);
dev->rx_poll.dev = dev;

dev->tx_poll.event = LKL_DEV_NET_POLL_TX;
-dev->tx_poll.sem = lkl_host_ops.sem_alloc(0);
dev->tx_poll.dev = dev;

-if (!dev->rx_poll.sem || !dev->tx_poll.sem)
-goto out_free;
+/* MUST match the number of queue locks we initialized. We
+ * could init the queues in virtio_dev_setup to help enforce
+ * this, but netdevs are the only flavor that need these
+ * locks, so it's better to do it here. */
+ret = virtio_dev_setup(&dev->dev, NUM_QUEUES, 32);

-ret = virtio_dev_setup(&dev->dev, 2, 32);
if (ret)
goto out_free;

@@ -167,10 +212,8 @@ int lkl_netdev_add(union lkl_netdev nd, void *mac)
virtio_dev_cleanup(&dev->dev);

out_free:
-if (dev->rx_poll.sem)
-lkl_host_ops.sem_free(dev->rx_poll.sem);
-if (dev->tx_poll.sem)
-lkl_host_ops.sem_free(dev->tx_poll.sem);
+if (dev->queue_locks)
+free_queue_locks(dev->queue_locks, NUM_QUEUES);
lkl_host_ops.mem_free(dev);

return ret;
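
The new per-queue locks rely on three host hooks used above: lkl_host_ops.mutex_alloc, lkl_host_ops.mutex_lock and lkl_host_ops.mutex_unlock. As a rough sketch of what a POSIX host might put behind those hooks (an assumption for illustration, not the actual LKL host library; the layout of struct lkl_mutex_t is invented here):

#include <pthread.h>
#include <stdlib.h>

/* Illustrative only: one possible pthreads backing for the mutex hooks
 * called by net_acquire_queue()/net_release_queue(). */
struct lkl_mutex_t {
	pthread_mutex_t mutex;
};

/* Counterpart of lkl_host_ops.mutex_alloc. Note that free_queue_locks
 * above releases each lock with lkl_host_ops.mem_free, so a faithful
 * implementation would allocate from the host allocator, not malloc. */
static struct lkl_mutex_t *host_mutex_alloc(void)
{
	struct lkl_mutex_t *m = malloc(sizeof(*m));

	if (!m)
		return NULL;
	if (pthread_mutex_init(&m->mutex, NULL) != 0) {
		free(m);
		return NULL;
	}
	return m;
}

/* Counterpart of lkl_host_ops.mutex_lock. */
static void host_mutex_lock(struct lkl_mutex_t *m)
{
	pthread_mutex_lock(&m->mutex);
}

/* Counterpart of lkl_host_ops.mutex_unlock. */
static void host_mutex_unlock(struct lkl_mutex_t *m)
{
	pthread_mutex_unlock(&m->mutex);
}

Using one mutex per queue rather than one per device keeps tx and rx processing independent of each other, which is consistent with the commit message's note that performance stays more or less the same as the original implementation.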