```cpp
#ifndef NDARRAY_H
#define NDARRAY_H

#include "cuda_runtime.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
// Shape/stride bookkeeping shared by all ndarray ranks. Strides are in
// bytes; dimension 0 is the contiguous one, so for a pitched 2D image
// shape = {width, height} and strides = {sizeof(element), row pitch}.
template<size_t N>
struct ndarray_base {
    using index_type = uint32_t;

    index_type shape[N] = {};
    index_type strides[N] = {};

    // Rank-constrained accessors: size() exists only for 1D arrays,
    // width()/height()/pitch() for rank >= 2, depth() for rank >= 3.
    template<size_t M = N> requires(M == 1)
    __host__ __device__ [[nodiscard]] index_type size() const {
        return shape[0];
    }

    template<size_t M = N> requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type width() const {
        return shape[0];
    }

    template<size_t M = N> requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type height() const {
        return shape[1];
    }

    template<size_t M = N> requires(M >= 3)
    __host__ __device__ [[nodiscard]] index_type depth() const {
        return shape[2];
    }

    // Row stride in bytes (e.g. the pitch reported by cudaMallocPitch).
    template<size_t M = N> requires(M >= 2)
    __host__ __device__ [[nodiscard]] index_type pitch() const {
        return strides[1];
    }

    // Payload bytes per row, excluding any pitch padding.
    template<size_t M = N> requires(M >= 2)
    __host__ __device__ [[nodiscard]] size_t byte_width() const {
        return strides[0] * shape[0];
    }
};
// Non-owning, typed view over (possibly pitched) host or device memory.
template<typename T, size_t N>
struct ndarray : ndarray_base<N> {
    using base_type = ndarray_base<N>;
    using typename base_type::index_type;
    using base_type::shape;
    using base_type::strides;

    void *data = nullptr;

    // Address of the element at an N-dimensional index. Byte strides keep
    // the arithmetic independent of sizeof(T). The explicit casts avoid
    // narrowing errors when callers pass signed index variables.
    template<typename... Dims>
        requires(sizeof...(Dims) == N)
    __host__ __device__ T *ptr(Dims... ds) {
        index_type indices[] = {static_cast<index_type>(ds)...};
        index_type offset = 0;
        for (size_t i = 0; i < N; i++) {
            assert(indices[i] < shape[i]);
            offset += indices[i] * strides[i];
        }
        return (T *) ((uint8_t *) data + offset);
    }
};
// Reinterpret the element type while keeping shape, strides, and data.
// The element stride must be able to hold the new type. Copying the
// fields avoids the strict-aliasing violation of casting &arr directly
// to an ndarray<U, N> *.
template<typename U, typename T, size_t N>
ndarray<U, N> type_cast(ndarray<T, N> arr) {
    assert(sizeof(U) <= arr.strides[0]);
    ndarray<U, N> ret;
    for (size_t i = 0; i < N; i++) {
        ret.shape[i] = arr.shape[i];
        ret.strides[i] = arr.strides[i];
    }
    ret.data = arr.data;
    return ret;
}

#endif //NDARRAY_H
```
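
Nothing in the header allocates memory; an `ndarray` is just a view. As a minimal sketch of intended usage, assuming the `{width, height}` shape and byte-stride convention above: the helper `make_image2d` and the `fill` kernel below are illustrative names, not part of the header.

```cpp
#include <cstdio>
#include "ndarray.h"

// Hypothetical helper: wrap a pitched device allocation in a 2D view.
// strides[0] is the element size, strides[1] the row pitch in bytes.
template<typename T>
ndarray<T, 2> make_image2d(uint32_t width, uint32_t height) {
    ndarray<T, 2> img;
    size_t pitch = 0;
    cudaMallocPitch(&img.data, &pitch, width * sizeof(T), height);
    img.shape[0] = width;
    img.shape[1] = height;
    img.strides[0] = sizeof(T);
    img.strides[1] = static_cast<uint32_t>(pitch);
    return img;
}

// Illustrative kernel: the view is trivially copyable, so it can be
// passed by value; ptr() performs the pitched address arithmetic.
__global__ void fill(ndarray<float, 2> img, float value) {
    uint32_t x = blockIdx.x * blockDim.x + threadIdx.x;
    uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x < img.width() && y < img.height())
        *img.ptr(x, y) = value;
}

int main() {
    auto img = make_image2d<float>(1024, 768);

    dim3 block(16, 16);
    dim3 grid((img.width() + block.x - 1) / block.x,
              (img.height() + block.y - 1) / block.y);
    fill<<<grid, block>>>(img, 1.0f);
    cudaDeviceSynchronize();

    // View each element through a byte pointer: shape and strides are
    // preserved, so sizeof(uint8_t) <= strides[0] and the assert holds.
    auto bytes = type_cast<uint8_t>(img);
    printf("pitch: %u, byte width: %zu\n", bytes.pitch(), bytes.byte_width());

    cudaFree(img.data);
    return 0;
}
```

Because strides are stored in bytes rather than element counts, `type_cast` can re-view the same pitched buffer at a different element granularity without touching the allocation.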