diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c index 8d4a43e6aa46..6e70ec3ea796 100644 --- a/drivers/xen/pvcalls-front.c +++ b/drivers/xen/pvcalls-front.c @@ -70,6 +70,13 @@ struct sock_mapping { wait_queue_head_t inflight_conn_req; } active; + struct { + /* Socket status */ +#define PVCALLS_STATUS_UNINITALIZED 0 +#define PVCALLS_STATUS_BIND 1 +#define PVCALLS_STATUS_LISTEN 2 + uint8_t status; + } passive; }; }; @@ -346,6 +353,65 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, return ret; } +int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + struct pvcalls_bedata *bedata; + struct sock_mapping *map = NULL; + struct xen_pvcalls_request *req; + int notify, req_id, ret; + + if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) + return -EOPNOTSUPP; + + pvcalls_enter(); + if (!pvcalls_front_dev) { + pvcalls_exit(); + return -ENOTCONN; + } + bedata = dev_get_drvdata(&pvcalls_front_dev->dev); + + map = (struct sock_mapping *) sock->sk->sk_send_head; + if (map == NULL) { + pvcalls_exit(); + return -ENOTSOCK; + } + + spin_lock(&bedata->socket_lock); + ret = get_request(bedata, &req_id); + if (ret < 0) { + spin_unlock(&bedata->socket_lock); + pvcalls_exit(); + return ret; + } + req = RING_GET_REQUEST(&bedata->ring, req_id); + req->req_id = req_id; + map->sock = sock; + req->cmd = PVCALLS_BIND; + req->u.bind.id = (uintptr_t)map; + memcpy(req->u.bind.addr, addr, sizeof(*addr)); + req->u.bind.len = addr_len; + + map->active_socket = false; + + bedata->ring.req_prod_pvt++; + RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); + spin_unlock(&bedata->socket_lock); + if (notify) + notify_remote_via_irq(bedata->irq); + + wait_event(bedata->inflight_req, + READ_ONCE(bedata->rsp[req_id].req_id) == req_id); + + /* read req_id, then the content */ + smp_rmb(); + ret = bedata->rsp[req_id].ret; + bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; + + map->passive.status = PVCALLS_STATUS_BIND; + pvcalls_exit(); + return 0; +} + static const struct xenbus_device_id pvcalls_front_ids[] = { { "pvcalls" }, { "" } diff --git a/drivers/xen/pvcalls-front.h b/drivers/xen/pvcalls-front.h index 63b0417c31d3..8b0a2749d4b8 100644 --- a/drivers/xen/pvcalls-front.h +++ b/drivers/xen/pvcalls-front.h @@ -6,5 +6,8 @@ int pvcalls_front_socket(struct socket *sock); int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, int addr_len, int flags); +int pvcalls_front_bind(struct socket *sock, + struct sockaddr *addr, + int addr_len); #endif