diff --git a/drivers/virtio/virtio_mem.c b/drivers/virtio/virtio_mem.c index 7e83ed373e00..bef8ad6bf466 100644 --- a/drivers/virtio/virtio_mem.c +++ b/drivers/virtio/virtio_mem.c @@ -143,6 +143,8 @@ struct virtio_mem { * add_memory_driver_managed(). */ const char *resource_name; + /* Memory group identification. */ + int mgid; /* * We don't want to add too much memory if it's not getting onlined, @@ -626,8 +628,8 @@ static int virtio_mem_add_memory(struct virtio_mem *vm, uint64_t addr, addr + size - 1); /* Memory might get onlined immediately. */ atomic64_add(size, &vm->offline_size); - rc = add_memory_driver_managed(vm->nid, addr, size, vm->resource_name, - MHP_MERGE_RESOURCE); + rc = add_memory_driver_managed(vm->mgid, addr, size, vm->resource_name, + MHP_MERGE_RESOURCE | MHP_NID_IS_MGID); if (rc) { atomic64_sub(size, &vm->offline_size); dev_warn(&vm->vdev->dev, "adding memory failed: %d\n", rc); @@ -2569,6 +2571,7 @@ static bool virtio_mem_has_memory_added(struct virtio_mem *vm) static int virtio_mem_probe(struct virtio_device *vdev) { struct virtio_mem *vm; + uint64_t unit_pages; int rc; BUILD_BUG_ON(sizeof(struct virtio_mem_req) != 24); @@ -2603,6 +2606,16 @@ static int virtio_mem_probe(struct virtio_device *vdev) if (rc) goto out_del_vq; + /* use a single dynamic memory group to cover the whole memory device */ + if (vm->in_sbm) + unit_pages = PHYS_PFN(memory_block_size_bytes()); + else + unit_pages = PHYS_PFN(vm->bbm.bb_size); + rc = memory_group_register_dynamic(vm->nid, unit_pages); + if (rc < 0) + goto out_del_resource; + vm->mgid = rc; + /* * If we still have memory plugged, we have to unplug all memory first. * Registering our parent resource makes sure that this memory isn't @@ -2617,7 +2630,7 @@ static int virtio_mem_probe(struct virtio_device *vdev) vm->memory_notifier.notifier_call = virtio_mem_memory_notifier_cb; rc = register_memory_notifier(&vm->memory_notifier); if (rc) - goto out_del_resource; + goto out_unreg_group; rc = register_virtio_mem_device(vm); if (rc) goto out_unreg_mem; @@ -2631,6 +2644,8 @@ static int virtio_mem_probe(struct virtio_device *vdev) return 0; out_unreg_mem: unregister_memory_notifier(&vm->memory_notifier); +out_unreg_group: + memory_group_unregister(vm->mgid); out_del_resource: virtio_mem_delete_resource(vm); out_del_vq: @@ -2695,6 +2710,7 @@ static void virtio_mem_remove(struct virtio_device *vdev) } else { virtio_mem_delete_resource(vm); kfree_const(vm->resource_name); + memory_group_unregister(vm->mgid); } /* remove all tracking data - no locking needed */