// memory_pool.cpp

#include "memory_pool_impl.h"
#include "core/cuda_helper.hpp"
#include "core/utility.hpp"
#include <boost/asio/io_context.hpp>
#include <boost/asio/post.hpp>
#include <spdlog/spdlog.h>
#include <cuda.h>
#include <cuda_runtime.h> // cudaMalloc, cudaFree, cudaEvent_*, cudaStream*
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <ranges>

using boost::asio::io_context;
using boost::asio::post;

memory_pool global_mp;
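
// Allocation bookkeeping. malloc_pool maps the base pointer of every live
// allocation to its mem_info record; reuse_host_pool / reuse_cuda_pool hold
// freed blocks keyed by byte count so later requests can recycle them.
// (The container types themselves are declared in memory_pool_impl.h.)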
void memory_pool::impl::reg_allocate(mem_info_type mem_info) {
    malloc_pool.emplace(mem_info.ptr, mem_info);
}
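
// Try to recycle a cached block of at least `count` bytes. lower_bound() picks
// the smallest cached block that is large enough; the reuse_threshold check
// (threshold defined in memory_pool_impl.h) rejects blocks that are far larger
// than the request, so a small request does not pin an oversized allocation.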
void *memory_pool::impl::try_reuse_host(size_t count) {
    auto iter = reuse_host_pool.lower_bound(count);
    if (iter == reuse_host_pool.end()) [[unlikely]] return nullptr;
    auto mem_info = iter->second;
    if (mem_info.count * reuse_threshold > count) [[unlikely]] return nullptr;
    reuse_host_pool.erase(iter);
    reg_allocate(mem_info);
    return mem_info.ptr;
}

void *memory_pool::impl::try_reuse_cuda(size_t count) {
    auto iter = reuse_cuda_pool.lower_bound(count);
    if (iter == reuse_cuda_pool.end()) [[unlikely]] return nullptr;
    auto mem_info = iter->second;
    if (mem_info.count * reuse_threshold > count) [[unlikely]] return nullptr;
    reuse_cuda_pool.erase(iter);
    reg_allocate(mem_info);
    return mem_info.ptr;
}
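
// Slow paths: obtain fresh memory from the system allocator / CUDA runtime
// when nothing suitable is cached, and register it as a live allocation.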
void *memory_pool::impl::direct_allocate_host(size_t count) {
    auto ptr = ::malloc(count);
    reg_allocate({.ptr = ptr, .loc = MEM_HOST, .lay = MEM_LINEAR, .count = count});
    return ptr;
}

void *memory_pool::impl::direct_allocate_cuda(size_t count) {
    void *ptr = nullptr;
    CUDA_API_CHECK(cudaMalloc(&ptr, count));
    reg_allocate({.ptr = ptr, .loc = MEM_CUDA, .lay = MEM_LINEAR, .count = count});
    return ptr;
}
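
// Fast path first: recycle a cached block if one fits, otherwise fall back to
// a fresh allocation.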
void *memory_pool::impl::allocate_host(size_t count) {
    if (auto ptr = try_reuse_host(count); ptr != nullptr) [[likely]] {
        return ptr;
    }
    return direct_allocate_host(count);
}

void *memory_pool::impl::allocate_cuda(size_t count) {
    if (auto ptr = try_reuse_cuda(count); ptr != nullptr) [[likely]] {
        return ptr;
    }
    return direct_allocate_cuda(count);
}
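
// The impl's entry points below take the pool mutex `mu`; the helpers above
// assume the lock is already held by the caller.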
void *memory_pool::impl::allocate(size_t count, memory_location mem_loc) {
    auto guard = std::lock_guard(mu);
    switch (mem_loc) {
        case MEM_HOST: {
            return allocate_host(count);
        }
        case MEM_CUDA: {
            return allocate_cuda(count);
        }
    }
    RET_ERROR_P;
}

void *memory_pool::impl::allocate_pitch(
        size_t width, size_t rows, memory_location mem_loc, size_t *pitch) {
    auto guard = std::lock_guard(mu);
    switch (mem_loc) {
        case MEM_HOST: {
            *pitch = width;
            return allocate_host(width * rows);
        }
        case MEM_CUDA: {
            if (width & 0x1F) { // round the pitch up to the next multiple of 32 bytes
                *pitch = (width + 0x20) & ~size_t{0x1F};
            } else {
                *pitch = width;
            }
            return allocate_cuda(*pitch * rows);
        }
    }
    RET_ERROR_P;
}
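
// Return the CUDA event associated with the allocation that `ptr` belongs to,
// creating it lazily (with timing disabled) on first use. record_create() and
// sync_create() below use this event to order producers and consumers of the
// buffer across streams.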
cudaEvent_t memory_pool::impl::get_event(void *ptr) {
    auto guard = std::lock_guard(mu);
    auto iter = malloc_pool.lower_bound(ptr);
    assert(iter != malloc_pool.end());
    auto &mem_info = iter->second;
    assert(static_cast<size_t>((char *) ptr - (char *) mem_info.ptr) < mem_info.count);
    if (mem_info.event == nullptr) [[unlikely]] {
        CUDA_API_CHECK(cudaEventCreateWithFlags(&mem_info.event, cudaEventDisableTiming));
    }
    assert(mem_info.event != nullptr);
    return mem_info.event;
}
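
// Deallocation does not return memory to the system: the block is moved from
// the live table into the matching reuse pool so a later allocation of similar
// size can pick it up. purge() is what actually releases cached blocks.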
void memory_pool::impl::deallocate(void *ptr) {
    auto guard = std::lock_guard(mu);
    auto iter = malloc_pool.find(ptr);
    if (iter == malloc_pool.end()) {
        SPDLOG_WARN("Deallocate unknown pointer: {}.", fmt::ptr(ptr));
        return;
    }
    auto mem_info = iter->second;
    malloc_pool.erase(iter);
    switch (mem_info.loc) {
        case MEM_HOST: {
            reuse_host_pool.emplace(mem_info.count, mem_info);
            return;
        }
        case MEM_CUDA: {
            reuse_cuda_pool.emplace(mem_info.count, mem_info);
            return;
        }
    }
    RET_ERROR;
}
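
// Hand a block back to the system allocator / CUDA runtime; only purge()
// calls this, when the cached blocks are evicted from the reuse pools.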
void memory_pool::impl::system_deallocate(mem_info_type mem_info) {
    switch (mem_info.loc) {
        case MEM_HOST: {
            ::free(mem_info.ptr);
            return;
        }
        case MEM_CUDA: {
            CUDA_API_CHECK(cudaFree(mem_info.ptr));
            return;
        }
    }
    RET_ERROR;
}

void memory_pool::impl::purge() {
    auto guard = std::lock_guard(mu);
    for (auto item : reuse_host_pool | std::views::values) {
        system_deallocate(item);
    }
    reuse_host_pool.clear();
    for (auto item : reuse_cuda_pool | std::views::values) {
        system_deallocate(item);
    }
    reuse_cuda_pool.clear();
}
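
// Public memory_pool interface: thin forwarding wrappers around the pimpl.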
void *memory_pool::allocate_impl(size_t count, memory_location mem_loc) {
    return pimpl->allocate(count, mem_loc);
}

void *memory_pool::allocate_pitch_impl(
        size_t width, size_t rows, memory_location mem_loc, size_t *pitch) {
    return pimpl->allocate_pitch(width, rows, mem_loc, pitch);
}
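
// Stream-ordering helpers. record_create() records the allocation's event on
// the stream that produces the buffer's contents; sync_create() later makes
// either the host (no stream) or another stream wait on that event before the
// buffer is consumed.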
void memory_pool::record_create(void *ptr, smart_cuda_stream *stream) {
    if (stream == nullptr) return;
    auto event = pimpl->get_event(ptr);
    CUDA_API_CHECK(cudaEventRecord(event, stream->cuda));
}

void memory_pool::sync_create(void *ptr, smart_cuda_stream *stream) {
    auto event = pimpl->get_event(ptr);
    if (stream == nullptr) {
        CUDA_API_CHECK(cudaEventSynchronize(event));
    } else {
        CUDA_API_CHECK(cudaStreamWaitEvent(stream->cuda, event));
    }
}

void memory_pool::deallocate(void *ptr) {
    return pimpl->deallocate(ptr);
}

void memory_pool::purge() {
    pimpl->purge();
}

memory_pool::memory_pool()
    : pimpl(std::make_unique<impl>()) {}

memory_pool::~memory_pool() = default;
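
// Usage sketch (illustrative only). It assumes the public header exposes a way
// to reach allocate_impl (e.g. a typed allocate wrapper) and that
// smart_cuda_stream carries its cudaStream_t in a member named `cuda`; neither
// is defined in this file.
//
//   auto *buf = static_cast<float *>(global_mp.allocate_impl(n * sizeof(float), MEM_CUDA));
//   // ... launch a kernel on stream `s` that fills `buf` ...
//   global_mp.record_create(buf, &s);   // mark `s` as the producer of buf
//   global_mp.sync_create(buf, &s2);    // make stream `s2` wait for the producer
//   global_mp.deallocate(buf);          // return buf to the reuse cache
//   global_mp.purge();                  // release all cached blocks to the system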