Skip to content

Commit

Permalink
allow temp_memory_handler to allocate memory for multiple times (#161)
Browse files Browse the repository at this point in the history
Fix for [issue 76](#76): allow `temp_memory_handle` to allocate memory multiple times.

Authors:
  - https://github.com/linhu-nv

Approvers:
  - Chuang Zhu (https://github.com/chuangz0)

URL: #161
  • Loading branch information
linhu-nv authored Apr 30, 2024
1 parent f8cadcf commit b453a1d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 7 deletions.
15 changes: 11 additions & 4 deletions cpp/src/wholememory_ops/temp_memory_handle.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,7 +31,7 @@ class temp_memory_handle {
// Destructor: hands cleanup of anything the handle still holds to free_memory().
~temp_memory_handle()
{
  free_memory();
}
void* device_malloc(size_t elt_count, wholememory_dtype_t data_type)
{
free_memory();
free_data();
wholememory_tensor_description_t tensor_description;
get_tensor_description(&tensor_description, elt_count, data_type);
ptr_ = temp_mem_fns_->malloc_fn(
Expand All @@ -40,7 +40,7 @@ class temp_memory_handle {
}
void* host_malloc(size_t elt_count, wholememory_dtype_t data_type)
{
free_memory();
free_data();
wholememory_tensor_description_t tensor_description;
get_tensor_description(&tensor_description, elt_count, data_type);
ptr_ = temp_mem_fns_->malloc_fn(
Expand All @@ -49,14 +49,21 @@ class temp_memory_handle {
}
/**
 * Allocate a temporary buffer of kind WHOLEMEMORY_MA_PINNED through the
 * registered malloc_fn callback.
 *
 * Any buffer this handle still holds is released first with free_data(), so
 * the same handle can be asked for memory more than once.
 *
 * @param elt_count Element count forwarded to get_tensor_description().
 * @param data_type Element type forwarded to get_tensor_description().
 * @return The pointer returned by malloc_fn; it is also stored in ptr_.
 */
void* pinned_malloc(size_t elt_count, wholememory_dtype_t data_type)
{
  // Only free_data() is called here. The earlier free_memory() call was the
  // pre-change line that free_data() replaces; keeping both would defeat the
  // purpose of allowing repeated allocations on one handle.
  free_data();
  wholememory_tensor_description_t tensor_description;
  get_tensor_description(&tensor_description, elt_count, data_type);
  ptr_ = temp_mem_fns_->malloc_fn(
    &tensor_description, WHOLEMEMORY_MA_PINNED, memory_context_, temp_mem_fns_->global_context);
  return ptr_;
}
// Accessor for the pointer currently stored in ptr_ (nullptr once free_data() has run).
[[nodiscard]] void* pointer() const
{
  return ptr_;
}
// Release the current allocation, if any, through the registered free_fn
// callback and clear the stored pointer. Does nothing when ptr_ is nullptr.
void free_data()
{
  if (ptr_ == nullptr) { return; }
  temp_mem_fns_->free_fn(memory_context_, temp_mem_fns_->global_context);
  ptr_ = nullptr;
}
void free_memory()
{
if (ptr_ != nullptr) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -75,6 +75,11 @@ def free(self):
torch_cpp_ext_lib.destroy_output_context(self.get_handle())
self.handle = 0

def free_data(self):
    """Drop the tensor reference and free the data held by the native context.

    The native ``free_context_data`` call is made only when the torch C++
    extension is loaded and this object's handle is non-zero.
    """
    self.tensor = None
    if not torch_cpp_ext_loaded:
        return
    if self.get_handle() == 0:
        return
    torch_cpp_ext_lib.free_context_data(self.get_handle())


def torch_create_memory_context_env_fn(
global_context: TorchEmptyGlobalContext,
Expand Down Expand Up @@ -121,7 +126,7 @@ def torch_malloc_env_fn(
def torch_free_env_fn(
    memory_context: TorchMemoryContext, global_context: TorchEmptyGlobalContext
):
    """Free callback: release the data held by ``memory_context``.

    Only ``free_data()`` is called. The earlier ``memory_context.free()`` call
    was the pre-change line that ``free_data()`` replaces; ``free()`` destroys
    the output context and zeroes the handle, which would stop the same
    context from being used for a later allocation.

    ``global_context`` is accepted to match the callback signature and is not
    used here.
    """
    memory_context.free_data()


class ExtContextWrapper(object):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,4 +60,8 @@ void destroy_output_context(void* output_context) {
destroy_torch_memory_context_func(output_context, nullptr);
}

// Free the data attached to an output context by forwarding the context to
// torch_common_free_func, with nullptr as the second argument.
void free_context_data(void* output_context)
{
  torch_common_free_func(output_context, nullptr);
}

} // namespace wholegraph_torch
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/script.h>

Expand All @@ -24,6 +39,11 @@ void wrapped_destroy_output_context(int64_t output_context)
wholegraph_torch::destroy_output_context(reinterpret_cast<void*>(output_context));
}

// Wrapper exposed to Python as "free_context_data": reinterprets the integer
// handle as an output-context pointer and frees the data attached to it.
void wrapped_free_context_data(int64_t output_context)
{
  // free_context_data takes only the context pointer; it supplies the nullptr
  // second argument to torch_common_free_func itself. This mirrors how
  // wrapped_destroy_output_context calls destroy_output_context.
  wholegraph_torch::free_context_data(reinterpret_cast<void*>(output_context));
}

torch::Tensor get_torch_tensor_from_output_context(int64_t output_context)
{
auto* torch_output_context =
Expand All @@ -39,6 +59,7 @@ PYBIND11_MODULE(pylibwholegraph_torch_ext, m)
m.def("get_stream", &wrapped_get_stream, "Get current CUDA stream.");
m.def("create_output_context", &wrapped_create_output_context, "Create output memory context.");
m.def("destroy_output_context", &wrapped_destroy_output_context, "Destroy output memory context.");
m.def("free_context_data", &wrapped_free_context_data, "Free data in output memory context.");
m.def("get_tensor_from_context",
&get_torch_tensor_from_output_context,
"Get PyTorch Tensor from output memory context");
Expand Down

0 comments on commit b453a1d

Please sign in to comment.