// process_kernels.cu

#include "process_kernels.cuh"
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>

// kernel templates
template<typename OutT, typename ReduceFunc, uint16_t BlockSize>
__device__ void warp_reduce(volatile OutT *s_buf, uint32_t tdx) {
    static_assert(std::is_fundamental_v<OutT>);
    if constexpr (BlockSize >= 64) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 32]);
    }
    if constexpr (BlockSize >= 32) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 16]);
    }
    if constexpr (BlockSize >= 16) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 8]);
    }
    if constexpr (BlockSize >= 8) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 4]);
    }
    if constexpr (BlockSize >= 4) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 2]);
    }
    if constexpr (BlockSize >= 2) {
        ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 1]);
    }
}
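// Alternative final-warp step, shown only as an illustrative sketch (it is not
// used by the kernels in this file): warp shuffle intrinsics keep the running
// value in a register instead of volatile shared memory. ReduceFunc::Op's
// volatile pointer parameter binds to &val through the usual qualification
// conversion.
template<typename OutT, typename ReduceFunc>
__device__ OutT warp_reduce_shfl(OutT val) {
    // fold lanes 16, 8, 4, 2, 1 into the current lane's register
    for (int offset = 16; offset > 0; offset >>= 1) {
        OutT other = __shfl_down_sync(0xffffffffu, val, offset);
        ReduceFunc::Op(&val, other);
    }
    return val;
}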
template<typename InT, typename OutT, typename UpdateFunc, typename ReduceFunc, OutT InitVal, uint16_t BlockSize>
__global__ void reduce_any(InT *in, OutT *out, uint32_t n) {
    extern __shared__ int shmem[];
    auto s_buf = (OutT *) shmem;
    uint32_t tdx = threadIdx.x;
    uint32_t bkx = blockIdx.x;
    uint32_t grid_size = BlockSize * gridDim.x;
    OutT t_out = InitVal;
    // accumulate this thread's grid-stride slice of the input
    for (uint32_t i = bkx * blockDim.x + tdx; i < n; i += grid_size) {
        UpdateFunc::Op(&t_out, in[i]);
    }
    // write the per-thread partial to shared memory
    s_buf[tdx] = t_out;
    __syncthreads();
    if constexpr (BlockSize >= 512) {
        if (tdx < 256) {
            ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 256]);
        }
        __syncthreads();
    }
    if constexpr (BlockSize >= 256) {
        if (tdx < 128) {
            ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 128]);
        }
        __syncthreads();
    }
    if constexpr (BlockSize >= 128) {
        if (tdx < 64) {
            ReduceFunc::Op(&s_buf[tdx], s_buf[tdx + 64]);
        }
        __syncthreads();
    }
    if (tdx < 32) {
        warp_reduce<OutT, ReduceFunc, BlockSize>(s_buf, tdx);
    }
    if (tdx == 0) {
        out[bkx] = s_buf[0];
    }
}
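// Illustrative trace (assumed launch, not additional code): with BlockSize = 256
// each thread first folds its grid-stride slice of `in` into t_out, the block then
// halves the shared buffer 256 -> 128 -> 64 with the __syncthreads()-guarded steps,
// warp_reduce finishes 64 -> 1 inside the first warp, and thread 0 writes one
// partial result per block to out[blockIdx.x].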
template<typename InT, typename OutT, typename Func>
__global__ void elementwise_any(InT *in, OutT *out, uint32_t n) {
    uint32_t tdx = threadIdx.x;
    uint32_t bkx = blockIdx.x;
    uint32_t grid_size = blockDim.x * gridDim.x;
    for (uint32_t i = bkx * blockDim.x + tdx; i < n; i += grid_size) {
        Func::Op(&out[i], in[i]);
    }
}
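// Grid-stride indexing example (assumed launch sizes): with blockDim.x = 256 and
// gridDim.x = 4, the thread at blockIdx.x = 1, threadIdx.x = 3 visits
// i = 259, 1283, 2307, ... until i >= n, so a fixed launch covers any n.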
template<typename InT, typename OutT, typename ExtT, typename Func>
__global__ void elementwise_ext_any(InT *in, OutT *out, uint32_t n, ExtT *p_ext) {
    uint32_t tdx = threadIdx.x;
    uint32_t bkx = blockIdx.x;
    uint32_t grid_size = blockDim.x * gridDim.x;
    // load extra values
    ExtT ext = *p_ext;
    for (uint32_t i = bkx * blockDim.x + tdx; i < n; i += grid_size) {
        Func::Op(&out[i], in[i], ext);
    }
}
template<typename InT, typename OutT, typename UpdateFunc, typename ReduceFunc, OutT InitVal>
void call_reduce_any_kernel(InT *in, OutT *out, uint32_t n,
                            uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    auto shmem_size = block_size * (1 + (block_size <= 32));
    auto shmem_length = shmem_size * sizeof(OutT);
    switch (block_size) {
        case 512: {
            constexpr uint16_t BlockSize = 512;
            auto reduce_func = reduce_any<InT, OutT, UpdateFunc, ReduceFunc, InitVal, BlockSize>;
            reduce_func<<<grid_dim, BlockSize, shmem_length, stream>>>(in, out, n);
            return;
        }
        case 256: {
            constexpr uint16_t BlockSize = 256;
            auto reduce_func = reduce_any<InT, OutT, UpdateFunc, ReduceFunc, InitVal, BlockSize>;
            reduce_func<<<grid_dim, BlockSize, shmem_length, stream>>>(in, out, n);
            return;
        }
        case 128: {
            constexpr uint16_t BlockSize = 128;
            auto reduce_func = reduce_any<InT, OutT, UpdateFunc, ReduceFunc, InitVal, BlockSize>;
            reduce_func<<<grid_dim, BlockSize, shmem_length, stream>>>(in, out, n);
            return;
        }
        default: {
            assert(false);
        }
    }
}
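// The switch above turns the runtime block_size into the compile-time BlockSize
// template argument, so the `if constexpr` branches in reduce_any and warp_reduce
// are resolved at compile time; supporting another block size means adding a case.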
// result resides in out[0]
template<typename InT, typename OutT, typename UpdateFunc, typename ReduceFunc, OutT InitVal>
void call_reduce_any(InT *in, OutT *out, uint32_t n,
                     uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    { // first step: one partial result per block
        auto helper_func = call_reduce_any_kernel<InT, OutT, UpdateFunc, ReduceFunc, InitVal>;
        helper_func(in, out, n, block_size, grid_dim, stream);
    }
    { // second step: reduce the partials with a single block
        auto helper_func = call_reduce_any_kernel<OutT, OutT, ReduceFunc, ReduceFunc, InitVal>;
        helper_func(out, out, grid_dim, block_size, 1, stream);
    }
}
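// Two-pass data flow (illustrative numbers): with grid_dim = 64 the first pass
// leaves 64 per-block partials in out[0..63]; the second pass runs one block that
// folds them with ReduceFunc, so the final value lands in out[0]. `out` therefore
// needs room for grid_dim elements, not just one.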
// working functions
template<typename T>
struct type_max_value {
    static constexpr T value = std::numeric_limits<T>::max();
};
template<typename T>
struct reduce_max_func {
    static __device__ __forceinline__ void Op(volatile T *out, T val) {
        *out = max(*out, val);
    }
};
template<typename T>
struct reduce_min_func {
    static __device__ __forceinline__ void Op(volatile T *out, T val) {
        *out = min(*out, val);
    }
};
template<typename T>
struct reduce_sum_func {
    static __device__ __forceinline__ void Op(volatile T *out, T val) {
        *out = *out + val;
    }
};
template<typename T>
struct update_log_sum_func {
    static constexpr T eps = (T) 1e-6;
    static __device__ __forceinline__ void Op(T *out, T val) {
        *out += log(val + eps);
    }
};
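// eps keeps log() finite when val == 0; for the normalized V values in [0, 1]
// fed into this reduction below, the bias it adds is negligible.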
template<typename InT, typename OutT>
struct rgb_extract_v_func { // Extract V value of HSV from RGB
    static __device__ __forceinline__ void Op(OutT *out, InT in) {
        if constexpr (std::is_floating_point_v<OutT>) {
            using InElemT = decltype(in.x);
            constexpr OutT factor = (OutT) 1 / type_max_value<InElemT>::value;
            *out = factor * max(max(in.x, in.y), in.z);
        } else {
            *out = max(max(in.x, in.y), in.z);
        }
    }
};
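// Example (illustrative values): for a uchar3 pixel (255, 128, 0) with OutT = float,
// factor = 1/255 and V = 255 * factor = 1.0f; with an integral OutT the raw channel
// maximum (255) is stored instead.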
struct enhance_v_func {
    static __device__ __forceinline__ void Op(float *out, float in, enhance_coeff ext) {
        *out = ext.norm_factor * log(in / ext.log_avg + 1);
    }
};
template<typename ImgT>
struct enhance_image_func {
    static __device__ __forceinline__ void Op(ImgT *p_out, ImgT in, enhance_coeff ext) {
        // convert RGB to HSV
        // https://www.rapidtables.com/convert/color/rgb-to-hsv.html
        using ImgElemT = decltype(in.x);
        static_assert(std::is_integral_v<ImgElemT>);
        ImgElemT c_max = max(max(in.x, in.y), in.z);
        ImgElemT c_min = min(min(in.x, in.y), in.z);
        ImgElemT delta = c_max - c_min;
        float h; // hue in units of 60 degrees (the factor of 60 is never applied)
        if (delta == 0) {
            h = 0;
        } else {
            float delta_inv = 1.0f / delta;
            if (c_max == in.x) { // c_max == r
                h = delta_inv * (in.y - in.z); // (g-b)/delta % 6
                if (h < 0) {
                    h += 6;
                }
            } else if (c_max == in.y) { // c_max == g
                h = delta_inv * (in.z - in.x) + 2; // (b-r)/delta + 2
            } else { // c_max == b
                h = delta_inv * (in.x - in.y) + 4; // (r-g)/delta + 4
            }
        }
        float s;
        if (c_max == 0) {
            s = 0;
        } else {
            s = (float) delta / c_max;
        }
        constexpr float v_factor = 1.0f / type_max_value<ImgElemT>::value;
        float v = v_factor * (float) c_max;
        // enhance V channel
        v = ext.norm_factor * log(v / ext.log_avg + 1);
        // convert HSV to RGB
        // https://www.rapidtables.com/convert/color/hsv-to-rgb.html
        float c = v * s;
        float x = c * (1 - fabsf(fmodf(h, 2) - 1)); // c * (1 - |h % 2 - 1|)
        float m = v - c;
        float r, g, b;
        switch ((uint8_t) h) {
            case 0: {
                r = c;
                g = x;
                b = 0;
                break;
            }
            case 1: {
                r = x;
                g = c;
                b = 0;
                break;
            }
            case 2: {
                r = 0;
                g = c;
                b = x;
                break;
            }
            case 3: {
                r = 0;
                g = x;
                b = c;
                break;
            }
            case 4: {
                r = x;
                g = 0;
                b = c;
                break;
            }
            case 5: {
                r = c;
                g = 0;
                b = x;
                break;
            }
            default: {
                // unreachable for h in [0, 6); keep r/g/b defined if asserts are compiled out
                r = g = b = 0;
                assert(false);
            }
        }
        constexpr float out_factor = type_max_value<ImgElemT>::value;
        ImgT out;
        out.x = out_factor * (r + m);
        out.y = out_factor * (g + m);
        out.z = out_factor * (b + m);
        *p_out = out;
    }
};
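// Worked example (illustrative pixel, ignoring the V enhancement step): for
// in = (200, 100, 50), c_max = 200, c_min = 50, delta = 150, so
// h = (100 - 50) / 150 = 0.33 (sector 0), s = 150 / 200 = 0.75, v = 200 / 255 = 0.78.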
// special kernels
__global__ void prepare_enhance_coeff(float *p_max_v, float *p_sum_log_v, uint32_t n,
                                      enhance_coeff *p_out) {
    float max_v = *p_max_v;
    float sum_log_v = *p_sum_log_v;
    float log_avg = exp(sum_log_v / n);
    float norm_factor = 1.0f / (log(max_v / log_avg + 1));
    p_out->log_avg = log_avg;
    p_out->norm_factor = norm_factor;
}
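// In math form: log_avg = exp((1/n) * sum_i log(V_i + eps)), i.e. the geometric mean
// of V, and norm_factor = 1 / log(max_V / log_avg + 1), so the per-pixel mapping
// V' = norm_factor * log(V / log_avg + 1) sends max_V to exactly 1.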
// calling endpoints
template<typename T>
void call_reduce_max(T *in, T *out, size_t n,
                     uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using FuncType = reduce_max_func<T>;
    // lowest(), not min(): for floating-point T, min() is the smallest positive value
    constexpr T InitVal = std::numeric_limits<T>::lowest();
    auto helper_func = call_reduce_any<T, T, FuncType, FuncType, InitVal>;
    helper_func(in, out, n, block_size, grid_dim, stream);
}
template void call_reduce_max(float *, float *, size_t, uint16_t, uint16_t, cudaStream_t);
template<typename T>
void call_reduce_min(T *in, T *out, size_t n,
                     uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using FuncType = reduce_min_func<T>;
    constexpr T InitVal = std::numeric_limits<T>::max();
    auto helper_func = call_reduce_any<T, T, FuncType, FuncType, InitVal>;
    helper_func(in, out, n, block_size, grid_dim, stream);
}
template void call_reduce_min(float *, float *, size_t, uint16_t, uint16_t, cudaStream_t);
template<typename T>
void call_reduce_sum(T *in, T *out, size_t n,
                     uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using FuncType = reduce_sum_func<T>;
    auto helper_func = call_reduce_any<T, T, FuncType, FuncType, (T) 0>;
    helper_func(in, out, n, block_size, grid_dim, stream);
}
template void call_reduce_sum(float *, float *, size_t, uint16_t, uint16_t, cudaStream_t);
template<typename T>
void call_reduce_log_sum(T *in, T *out, size_t n,
                         uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using UpdateFuncType = update_log_sum_func<T>;
    using ReduceFuncType = reduce_sum_func<T>;
    auto helper_func = call_reduce_any<T, T, UpdateFuncType, ReduceFuncType, (T) 0>;
    helper_func(in, out, n, block_size, grid_dim, stream);
}
template void call_reduce_log_sum(float *, float *, size_t, uint16_t, uint16_t, cudaStream_t);
template<typename InT, typename OutT>
void call_rgb_extract_v(InT *in, OutT *out, size_t n,
                        uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using FuncType = rgb_extract_v_func<InT, OutT>;
    elementwise_any<InT, OutT, FuncType><<<grid_dim, block_size, 0, stream>>>(in, out, n);
}
template void call_rgb_extract_v(uchar3 *, float *, size_t, uint16_t, uint16_t, cudaStream_t);
void call_prepare_enhance_coeff(float *max_v, float *sum_log_v, uint32_t n,
                                enhance_coeff *out, cudaStream_t stream) {
    prepare_enhance_coeff<<<1, 1, 0, stream>>>(max_v, sum_log_v, n, out);
}
void call_enhance_v(float *in, float *out, size_t n, enhance_coeff *ext,
                    uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    auto kernel_func = elementwise_ext_any<float, float, enhance_coeff, enhance_v_func>;
    kernel_func<<<grid_dim, block_size, 0, stream>>>(in, out, n, ext);
}
template<typename ImgT>
void call_enhance_image(ImgT *in, ImgT *out, size_t n, enhance_coeff *ext,
                        uint16_t block_size, uint16_t grid_dim, cudaStream_t stream) {
    assert(n <= std::numeric_limits<uint32_t>::max());
    using FuncType = enhance_image_func<ImgT>;
    auto kernel_func = elementwise_ext_any<ImgT, ImgT, enhance_coeff, FuncType>;
    kernel_func<<<grid_dim, block_size, 0, stream>>>(in, out, n, ext);
}
template void call_enhance_image(uchar3 *, uchar3 *, size_t, enhance_coeff *, uint16_t, uint16_t, cudaStream_t);
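// End-to-end usage sketch (assumed wiring; the function name, buffer layout and
// launch sizes below are hypothetical and not part of this file's API). It chains
// the endpoints above on one stream: extract V, reduce its max and its log-sum,
// build the coefficients on device, then enhance the image in place.
// d_scratch must hold at least grid_dim + 1 floats: index 0 receives max_V and
// indices 1..grid_dim receive the log-sum partials / result.
void example_enhance_pipeline(uchar3 *d_img, float *d_v, float *d_scratch,
                              enhance_coeff *d_coeff, size_t n, cudaStream_t stream) {
    constexpr uint16_t block_size = 256;
    constexpr uint16_t grid_dim = 64;
    call_rgb_extract_v(d_img, d_v, n, block_size, grid_dim, stream);             // V channel
    call_reduce_max(d_v, d_scratch, n, block_size, grid_dim, stream);            // max_V -> d_scratch[0]
    call_reduce_log_sum(d_v, d_scratch + 1, n, block_size, grid_dim, stream);    // sum log V -> d_scratch[1]
    call_prepare_enhance_coeff(d_scratch, d_scratch + 1, (uint32_t) n, d_coeff, stream);
    call_enhance_image(d_img, d_img, n, d_coeff, block_size, grid_dim, stream);  // in-place enhance
}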