#ifndef NDARRAY_H
#define NDARRAY_H

#include "cuda_runtime.h"

#include <cassert>
#include <cstddef>
#include <cstdint>

/**
 * Shape/stride bookkeeping shared by every N-dimensional array view.
 *
 * `shape[i]` is the extent of axis i in elements and `strides[i]` is the step
 * between consecutive indices along axis i, measured in BYTES (ndarray::ptr
 * adds the strides to a byte pointer). Axis 0 is exposed as "width", axis 1 as
 * "height" and axis 2 as "depth".
 *
 * Each accessor is constrained on the rank, so e.g. `depth()` only exists for
 * rank >= 3. The dummy parameter M defaults to N; it exists only to make the
 * constraint dependent.
 *
 * NOTE(review): the non-type template parameters are declared as `size_t`
 * here; confirm this matches how callers spell/deduce the rank.
 */
template<size_t N>
struct ndarray_base {
    using index_type = uint32_t;

    index_type shape[N] = {};    // extent of each axis, in elements
    index_type strides[N] = {};  // step along each axis, in bytes

    /** Number of elements of a rank-1 array. */
    template<size_t M = N>
        requires(M == 1)
    __host__ __device__ [[nodiscard]] index_type size() const {
        return shape[0];
    }

    /** Extent of axis 0 (rank >= 2). */
    template<size_t M = N>
        requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type width() const {
        return shape[0];
    }

    /** Extent of axis 1 (rank >= 2). */
    template<size_t M = N>
        requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type height() const {
        return shape[1];
    }

    /** Extent of axis 2 (rank >= 3). */
    template<size_t M = N>
        requires(M >= 3)
    __host__ __device__ [[nodiscard]] index_type depth() const {
        return shape[2];
    }

    /** Byte step between consecutive indices of axis 1 (rank >= 2). */
    template<size_t M = N>
        requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type pitch() const {
        return strides[1];
    }

    /**
     * Bytes spanned by one run along axis 0: strides[0] * shape[0] (rank >= 2).
     *
     * The product is formed in size_t; multiplying the two 32-bit operands
     * directly would wrap before the result is widened to the return type.
     */
    template<size_t M = N>
        requires(M >= 2)
    __host__ __device__ [[nodiscard]] size_t byte_width() const {
        return static_cast<size_t>(strides[0]) * shape[0];
    }
};

/**
 * Typed N-dimensional view: an untyped base pointer plus the shape and byte
 * strides inherited from ndarray_base. Nothing in this struct allocates or
 * frees `data`.
 *
 * NOTE(review): parameter order <T, N> is assumed; confirm against callers.
 */
template<typename T, size_t N>
struct ndarray : ndarray_base<N> {
    using base_type = ndarray_base<N>;
    using typename base_type::index_type;
    using base_type::shape;
    using base_type::strides;

    void *data = nullptr;  // address of element (0, ..., 0)

    /**
     * Address of the element at the given per-axis indices.
     *
     * Exactly N indices must be supplied, in axis order (axis 0 first). Every
     * index is checked against its extent with assert(), so the check is
     * compiled out when NDEBUG is defined.
     *
     * @return data + sum(ds[i] * strides[i]) bytes, as a T*.
     */
    template<typename... Dims>
        requires(sizeof...(Dims) == N)
    __host__ __device__ T *ptr(Dims... ds) {
        const index_type indices[] = {ds...};
        // Accumulate in size_t: a 32-bit byte offset wraps for buffers >= 4 GiB.
        size_t offset = 0;
        for (size_t i = 0; i < N; ++i) {
            assert(indices[i] < shape[i]);
            offset += static_cast<size_t>(indices[i]) * strides[i];
        }
        return reinterpret_cast<T *>(static_cast<uint8_t *>(data) + offset);
    }
};

/**
 * Reinterpret an array view as holding elements of type U.
 *
 * Shape, strides and the data pointer are carried over unchanged; only the
 * element type seen by ptr() differs. Asserts that a U fits inside the axis-0
 * stride of the source view.
 *
 * NOTE(review): parameter order <U, T, N> (U explicit, T and N deduced) is
 * assumed; confirm against callers.
 */
template<typename U, typename T, size_t N>
ndarray<U, N> type_cast(ndarray<T, N> arr) {
    assert(sizeof(U) <= arr.strides[0]);
    using ret_type = ndarray<U, N>;
    // Copy field by field instead of reading `arr` through a pointer to an
    // unrelated class type, which violates the strict-aliasing rules.
    ret_type ret;
    for (size_t i = 0; i < N; ++i) {
        ret.shape[i] = arr.shape[i];
        ret.strides[i] = arr.strides[i];
    }
    ret.data = arr.data;
    return ret;
}

#endif //NDARRAY_H