ndarray.hpp 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #ifndef NDARRAY_H
  2. #define NDARRAY_H
  3. #include "cuda_runtime.h"
  4. #include <cassert>
  5. #include <cstdint>
  /// Rank-N shape/stride metadata shared by every ndarray<T, N>.
  ///
  /// shape[i]   : extent along axis i (axis 0 = width, 1 = height, 2 = depth,
  ///              per the accessors below).
  /// strides[i] : step between consecutive indices along axis i. ndarray::ptr
  ///              adds indices[i] * strides[i] to a byte pointer, so strides
  ///              are distances in bytes.
  /// Both arrays are zero-initialised and stored as 32-bit unsigned values.
  ///
  /// Each accessor is a template on a defaulted parameter M (= N), so that its
  /// `requires` clause can enable it only for ranks where it is meaningful.
  template<size_t N>
  struct ndarray_base {
      using index_type = uint32_t;

      index_type shape[N] = {};
      index_type strides[N] = {};

      /// Element count of a 1-D array (shape[0]).
      template<size_t M = N> requires(M == 1)
      [[nodiscard]] __host__ __device__ index_type size() const {
          return shape[0];
      }

      /// Extent along axis 0 (shape[0]).
      template<size_t M = N> requires(M >= 2)
      [[nodiscard]] __host__ __device__ index_type width() const {
          return shape[0];
      }

      /// Extent along axis 1 (shape[1]).
      template<size_t M = N> requires(M >= 2)
      [[nodiscard]] __host__ __device__ index_type height() const {
          return shape[1];
      }

      /// Extent along axis 2 (shape[2]).
      template<size_t M = N> requires(M >= 3)
      [[nodiscard]] __host__ __device__ index_type depth() const {
          return shape[2];
      }

      /// Step between consecutive rows (strides[1]).
      template<size_t M = N> requires(M >= 2)
      [[nodiscard]] __host__ __device__ index_type pitch() const {
          return strides[1];
      }

      /// Bytes spanned by one row's elements: strides[0] * shape[0].
      /// The product is formed in size_t, matching the return type. If it were
      /// formed in 32-bit index_type it would wrap for rows of 4 GiB or more.
      template<size_t M = N> requires(M >= 2)
      [[nodiscard]] __host__ __device__ size_t byte_width() const {
          return static_cast<size_t>(strides[0]) * shape[0];
      }
  };
  /// Typed, strided N-dimensional view over memory addressed by `data`.
  ///
  /// Nothing in this header allocates or frees `data`; it is simply the base
  /// address that ptr() offsets from. The element type T only affects the
  /// pointer type that ptr() returns, because all offsets are byte offsets
  /// taken from `strides`.
  template<typename T, size_t N>
  struct ndarray : ndarray_base<N> {
      using base_type = ndarray_base<N>;
      using typename base_type::index_type;
      using base_type::shape;
      using base_type::strides;

      void *data = nullptr;

      /// Returns a pointer to the element at index (ds...), one index per axis.
      ///
      /// In debug builds each index is checked against shape[i] with assert.
      /// The indices may be any integer type; each is converted to index_type.
      /// A plain braced init `{ds...}` is a narrowing error for signed or
      /// 64-bit arguments, so the conversion is explicit.
      template<typename... Dims>
      requires(sizeof...(Dims) == N)
      __host__ __device__ T *ptr(Dims... ds) {
          const index_type indices[] = {static_cast<index_type>(ds)...};
          // Accumulate in size_t: the product of a 32-bit index and a 32-bit
          // byte stride, summed over axes, can exceed 2^32 for large 2-D/3-D
          // arrays, and a 32-bit offset would silently wrap.
          size_t offset = 0;
          for (size_t i = 0; i < N; ++i) {
              assert(indices[i] < shape[i]);
              offset += static_cast<size_t>(indices[i]) * strides[i];
          }
          return reinterpret_cast<T *>(static_cast<uint8_t *>(data) + offset);
      }
  };
  /// Returns a view of `arr` whose element type is U instead of T.
  ///
  /// The shape, strides and data pointer are copied unchanged, so the result
  /// addresses the same memory. A U must fit within one innermost step
  /// (sizeof(U) <= strides[0]); this is checked with assert in debug builds.
  template<typename U, typename T, size_t N>
  __host__ __device__ ndarray<U, N> type_cast(ndarray<T, N> arr) {
      assert(sizeof(U) <= arr.strides[0]);
      ndarray<U, N> ret;
      // Copy the shared metadata through the common base instead of
      // reinterpreting the whole object. ndarray<T, N> and ndarray<U, N> are
      // distinct types, and reading one through a pointer to the other violates
      // the strict-aliasing rules.
      static_cast<typename ndarray<U, N>::base_type &>(ret) = arr;
      ret.data = arr.data;
      return ret;
  }
  61. #endif //NDARRAY_H