소스 검색

Embedded AI module FastSAM.

jcsyshc 1 년 전
부모
커밋
75a157e2f9

+ 10 - 0
CMakeLists.txt

@@ -4,6 +4,8 @@ project(DepthGuide)
 set(CMAKE_CXX_STANDARD 20)
 
 add_executable(${PROJECT_NAME} src/main.cpp
+        src/ai/impl/fast_sam.cpp
+        src/ai/impl/zmq_client.cpp
         src/image_process/impl/image_process_ui.cpp
         src/image_process/impl/versatile_convertor.cpp
         src/impl/main_impl.cpp
@@ -142,6 +144,10 @@ if (WIN32)
     target_link_directories(${PROJECT_NAME} PRIVATE C:/BuildEssentials/VS2019Libs/lib)
 endif ()
 
+# JSON config
+find_package(nlohmann_json REQUIRED)
+target_link_libraries(${PROJECT_NAME} nlohmann_json::nlohmann_json)
+
 # Boost config
 find_package(Boost REQUIRED COMPONENTS iostreams)
 target_include_directories(${PROJECT_NAME} PRIVATE ${Boost_INCLUDE_DIRS})
@@ -226,6 +232,10 @@ target_link_libraries(${PROJECT_NAME} CUDA::nvjpeg)
 find_package(assimp REQUIRED)
 target_link_libraries(${PROJECT_NAME} assimp::assimp)
 
+# ZeroMQ config
+find_package(azmq REQUIRED)
+target_link_libraries(${PROJECT_NAME} Azmq::azmq)
+
 # Sophiar2 config
 if (WIN32)
     set(Sophiar2DIR D:/Program/Robot/Sophiar2)

+ 34 - 0
src/ai/fast_sam.h

@@ -0,0 +1,34 @@
+#ifndef DEPTHGUIDE_FAST_SAM_H
+#define DEPTHGUIDE_FAST_SAM_H
+
+#include "core/cuda_helper.hpp"
+#include "core/object_manager.h"
+
+#include <boost/asio/io_context.hpp>
+
+#include <memory>
+
+class fast_sam { // FastSAM segmentation module with an ImGui front end; implementation lives behind pimpl
+public:
+
+    using io_context = boost::asio::io_context;
+
+    struct create_config { // construction parameters; a copy is kept by the implementation
+        obj_name_type img_in = invalid_obj_name; // image_u8c3
+        obj_name_type mask_out = invalid_obj_name; // image_ptr (the reply image is saved as-is, tagged COLOR_BW)
+        CUcontext *cuda_ctx = nullptr; // handed to the ZeroMQ client when segmentation is started
+        io_context *ctx = nullptr; // start/stop actions triggered from show() are posted here
+    };
+
+    explicit fast_sam(create_config conf);
+
+    ~fast_sam(); // defined out of line, where impl is a complete type
+
+    void show(); // draws the Start/Stop button
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+#endif //DEPTHGUIDE_FAST_SAM_H

+ 82 - 0
src/ai/impl/fast_sam.cpp

@@ -0,0 +1,82 @@
+#include "fast_sam_impl.h"
+#include "core/imgui_utility.hpp"
+
+#include <boost/asio/post.hpp>
+
+using boost::asio::post;
+
+// Keeps a copy of the configuration and caches its io_context pointer.
+fast_sam::impl::impl(create_config _conf)
+        : conf(_conf), ctx(_conf.ctx) {
+}
+
+// Invoked when the input image object changes: wraps the frame into a
+// request with a single positive point prompt at the image center and
+// hands it to the client. Frames arriving while the client is busy are
+// skipped.
+void fast_sam::impl::image_callback(obj_name_type name) {
+    assert(client != nullptr);
+    const bool client_busy = !client->is_ideal();
+    if (client_busy) return;
+    assert(name == conf.img_in);
+
+    auto frame_v1 = OBJ_QUERY(image_u8c3, name);
+    auto frame = create_image(frame_v1);
+
+    using namespace nlohmann;
+    auto center = json::array(
+            {frame->width() / 2, frame->height() / 2}); // center point
+    auto prompt = json();
+    prompt["prompt_type"] = "point";
+    prompt["points"] = json::array({center});
+    prompt["point_label"] = json::array({1});
+    prompt["imgsz"] = 1024;
+
+    auto request = zmq_client::request_type();
+    request.img_list.push_back(frame);
+    request.extra = prompt;
+    client->request(request);
+}
+
+// Publishes the first image of a reply as the mask object, tagged as
+// black-and-white.
+// NOTE(review): the client invokes this from its worker thread (on_reply
+// runs on the auxiliary io_context) - confirm OBJ_SAVE is safe to call there.
+void fast_sam::impl::reply_callback(const zmq_client::reply_type &rep) {
+    const auto &images = rep.img_list;
+    assert(!images.empty());
+    auto mask = images.front();
+    mask->set_meta(META_COLOR_FMT, COLOR_BW);
+    OBJ_SAVE(conf.mask_out, mask);
+}
+
+// Creates the ZeroMQ client and subscribes to the input image signal.
+void fast_sam::impl::start() {
+    assert(client == nullptr);
+
+    auto client_conf = zmq_client::create_config();
+    // alternative endpoint used during development: "tcp://127.0.0.1:7899"
+    client_conf.serv_addr = "ipc:///tmp/fast_sam_v1";
+    client_conf.reply_cb = [this](auto rep) { reply_callback(rep); };
+    client_conf.cuda_ctx = conf.cuda_ctx;
+    client = std::make_unique<zmq_client>(client_conf);
+
+    auto on_image = [this](auto name) { image_callback(name); };
+    img_conn = OBJ_SIG(conf.img_in)->connect(on_image);
+}
+
+// Unsubscribes from the image signal first, then destroys the client.
+void fast_sam::impl::stop() {
+    img_conn.disconnect();
+    client.reset();
+}
+
+// Draws one button whose action depends on whether a client exists;
+// the action itself is posted to the io_context rather than run inline.
+void fast_sam::impl::show() {
+    const bool running = (client != nullptr);
+    if (running) {
+        if (ImGui::Button("Stop")) {
+            post(*ctx, [this] { stop(); });
+        }
+        return;
+    }
+    if (ImGui::Button("Start")) {
+        post(*ctx, [this] { start(); });
+    }
+}
+
+// Public facade: every call is forwarded to the hidden implementation.
+fast_sam::fast_sam(create_config conf) {
+    pimpl = std::make_unique<impl>(conf);
+}
+
+// Defaulted here because impl is a complete type in this translation unit.
+fast_sam::~fast_sam() = default;
+
+void fast_sam::show() {
+    auto &self = *pimpl;
+    self.show();
+}

+ 30 - 0
src/ai/impl/fast_sam_impl.h

@@ -0,0 +1,30 @@
+#ifndef DEPTHGUIDE_FAST_SAM_IMPL_H
+#define DEPTHGUIDE_FAST_SAM_IMPL_H
+
+#include "ai/fast_sam.h"
+#include "ai/zmq_client.h"
+#include "core/image_utility.hpp"
+
+struct fast_sam::impl { // private state and logic behind fast_sam
+
+    create_config conf; // copy of the construction parameters
+    io_context *ctx = nullptr; // same pointer as conf.ctx; start()/stop() are posted to it
+
+    obj_conn_type img_conn; // connection to the input image signal, made in start()
+    std::unique_ptr<zmq_client> client; // non-null only between start() and stop()
+
+    explicit impl(create_config conf);
+
+    void image_callback(obj_name_type name); // sends the new frame as a request unless the client is busy
+
+    void reply_callback(const zmq_client::reply_type &rep); // saves the first reply image to conf.mask_out
+
+    void start(); // creates the client and connects the image signal
+
+    void stop(); // disconnects the image signal and destroys the client
+
+    void show(); // draws the Start/Stop button
+
+};
+
+#endif //DEPTHGUIDE_FAST_SAM_IMPL_H

+ 208 - 0
src/ai/impl/zmq_client.cpp

@@ -0,0 +1,208 @@
+#include "zmq_client_impl.h"
+
+#include <boost/asio/post.hpp>
+#include <boost/endian.hpp>
+
+using boost::asio::buffer;
+using boost::asio::post;
+using boost::system::error_code;
+
+namespace zmq_client_impl {
+
+//    zmq::context_t zmq_ctx;
+
+    // Maps the depth of an OpenCV type to a dtype tag string; the byte-order
+    // prefix of multi-byte depths follows the host. Other depths are an error.
+    const char *cv_type_to_dtype(int type) {
+        static constexpr bool host_is_le =
+                boost::endian::order::native == boost::endian::order::little;
+        const auto depth = CV_MAT_DEPTH(type);
+        // @formatter:off
+        if (depth == CV_8U)  { return "|u1"; }
+        if (depth == CV_16U) { return host_is_le ? "<u2" : ">u2"; }
+        if (depth == CV_32F) { return host_is_le ? "<f4" : ">f4"; }
+        // @formatter:on
+        RET_ERROR_E;
+    }
+
+    // Maps a dtype tag string plus a channel count to an OpenCV type.
+    //
+    // decode_reply() reads pixel bytes straight into image memory without
+    // any byte swapping, so multi-byte samples are accepted only in host
+    // byte order. This mirrors cv_type_to_dtype(), which emits host-order
+    // tags; previously only the little-endian tags were accepted regardless
+    // of the host. Behavior on little-endian hosts is unchanged.
+    // Unknown tags are an error.
+    int dtype_to_cv_type(const std::string &dtype, int c) {
+        static constexpr bool le =
+                boost::endian::order::native == boost::endian::order::little;
+        // @formatter:off
+        if (dtype == "|b1") { return CV_8UC(c); }
+        if (dtype == "|u1") { return CV_8UC(c); }
+        if (dtype == (le ? "<u2" : ">u2")) { return CV_16UC(c); }
+        if (dtype == (le ? "<f4" : ">f4")) { return CV_32FC(c); }
+        // @formatter:on
+        RET_ERROR_E;
+    }
+
+    // Serializes a request into one buffer: a size-prefixed JSON header
+    // (model name, per-image dtype/shape/transfer mode, extra payload)
+    // followed by each image as a size-prefixed block of raw host bytes.
+    data_type encode_request(const zmq_client::request_type &req,
+                             smart_cuda_stream *stream) {
+        using nlohmann::json;
+        const auto &images = req.img_list;
+
+        // ask for the cuda -> host transfer of every image in advance
+        for (const auto &image: images) {
+            image->memory(MEM_HOST, stream);
+        }
+
+        assert(!images.empty());
+        auto image_descs = json::array();
+        for (const auto &image: images) {
+            auto desc = json::object();
+            desc["dtype"] = cv_type_to_dtype(image->cv_type());
+            desc["shape"] = json::array(
+                    {image->height(), image->width(), CV_MAT_CN(image->cv_type())});
+            desc["transfer"] = "direct";
+            image_descs.push_back(desc);
+        }
+
+        auto header = json();
+        header["model"] = req.model_name;
+        header["images"] = image_descs;
+        header["extra"] = req.extra;
+        auto header_text = header.dump();
+        size_t header_len = header_text.length();
+
+        // every section is preceded by a size_t length field
+        size_t packet_len = sizeof(size_t) + header_len;
+        for (const auto &image: images) {
+            packet_len += sizeof(size_t) + image->size_in_bytes();
+        }
+        auto packet = data_type(packet_len);
+
+        auto writer = network_writer(packet);
+        writer << header_len << header_text;
+        for (const auto &image: images) {
+            size_t image_len = image->size_in_bytes();
+            writer << image_len;
+            auto host_mem = image->memory(MEM_HOST, stream);
+            assert(host_mem.is_continuous());
+            assert(host_mem.width * host_mem.height == image->size_in_bytes());
+            // wait for the stream before reading the host bytes
+            CUDA_API_CHECK(cudaStreamSynchronize(stream->cuda));
+            writer.write_data((uint8_t *) host_mem.start_ptr(),
+                              image->size_in_bytes());
+        }
+
+        // the buffer was sized exactly, so nothing may be left over
+        assert(writer.empty());
+        return packet;
+    }
+
+    zmq_client::reply_type decode_reply(const data_type &data) { // parses a reply: size-prefixed JSON header, then one size-prefixed raw block per image
+        using namespace nlohmann;
+
+        auto reader = network_reader(data);
+        auto head_size = reader.read_value<size_t>(); // header length comes first
+        auto head_str = reader.read_std_string(head_size);
+        auto head = json::parse(head_str);
+
+        auto ret = zmq_client::reply_type();
+        ret.extra = head["extra"];
+
+        for (auto k = 0; k < head["images"].size(); ++k) { // image blocks follow in the order listed in the header
+            auto &img_info = head["images"][k];
+//            assert(img_info["id"] == k);
+            assert(img_info["transfer"] == "direct"); // only in-band pixel data is handled here
+
+            auto img_shape = img_info["shape"];
+            assert(img_shape.is_array());
+            assert(img_shape.size() == 2
+                   || img_shape.size() == 3);
+            auto img_height = img_shape[0].get<int>(); // shape is (height, width[, channels])
+            auto img_width = img_shape[1].get<int>();
+            int img_channels = 1; // a 2-element shape means a single channel
+            if (img_shape.size() == 3) {
+                img_channels = img_shape[2].get<int>();
+            }
+
+            auto img_dtype = img_info["dtype"].get<std::string>();
+            auto img_type = dtype_to_cv_type(img_dtype, img_channels);
+            auto img = create_image(cv::Size(img_width, img_height), img_type);
+
+            auto img_mem = img->memory(MEM_HOST);
+            auto img_bytes = reader.read_value<size_t>(); // per-image length prefix
+            assert(img_bytes == img->size_in_bytes()); // must agree with the shape/dtype from the header
+            assert(img_mem.is_continuous());
+            reader.read_data((uint8_t *) img_mem.start_ptr(), img_bytes); // raw bytes go straight into host memory
+            img->host_modified(); // report that the host copy was written
+            ret.img_list.push_back(img);
+        }
+
+        return ret;
+    }
+
+}
+
+// Creates the private io_context and starts the worker thread on it.
+// The configuration is copied into the thread's closure.
+zmq_client::impl::impl(const create_config &conf)
+        : aux_ctx(std::make_unique<io_context>()) {
+    aux_thread = std::make_unique<std::thread>(
+            [conf, this]() { aux_thread_work(conf); });
+}
+
+// Stops the io_context so the worker's run() returns, then joins it.
+zmq_client::impl::~impl() {
+    auto &worker_ctx = *aux_ctx;
+    worker_ctx.stop();
+    aux_thread->join();
+}
+
+// Body of the worker thread: binds the CUDA context, opens and connects
+// the request socket, then runs the io_context until it is stopped.
+void zmq_client::impl::aux_thread_work(const create_config &conf) {
+    CUDA_API_CHECK(cuCtxSetCurrent(*conf.cuda_ctx));
+    reply_cb = conf.reply_cb;
+
+    socket = std::make_unique<azmq::req_socket>(*aux_ctx, true);
+    socket->connect(conf.serv_addr);
+
+    // the work guard keeps run() from returning while no handler is queued
+    auto keep_running = boost::asio::make_work_guard(*aux_ctx);
+    aux_ctx->run();
+
+    // cleanup
+    socket.reset();
+}
+
+// Runs on the worker thread: encodes the request and starts the send.
+// The busy flag has already been claimed by on_request().
+void zmq_client::impl::aux_on_request(const request_type &req) {
+    assert(socket_busy.test());
+
+    // send request
+    auto req_data = encode_request(req, &aux_stream);
+    auto req_buf = buffer(req_data.start_ptr(), req_data.size);
+    // req_data is captured explicitly. With an implicit [=] capture it was
+    // only referenced inside assert(), so with NDEBUG it was not captured at
+    // all and the buffer could be released while the send was in flight.
+    // (Assumes copies of data_type share storage, as try_recv_rep() relies on.)
+    socket->async_send(req_buf, [req_data, this](error_code ec, size_t size) {
+        (void) req_data; // held only to keep the buffer alive
+        assert(!ec);
+        assert(size == req_data.size);
+        if (ec) { // release builds: give the socket back instead of staying busy forever
+            socket_busy.clear();
+            return;
+        }
+        try_recv_rep();
+    });
+}
+
+// Starts an asynchronous receive into a fixed-size buffer and forwards
+// the received bytes to on_reply().
+void zmq_client::impl::try_recv_rep() {
+    static constexpr auto max_rep_size = 32 * 1024 * 1024; // 32MB
+    auto rep_data = data_type(max_rep_size);
+    auto rep_buf = buffer(rep_data.start_ptr(), rep_data.size);
+    socket->async_receive(rep_buf, [rep_data, this](error_code ec, size_t size) mutable {
+        assert(!ec);
+        if (ec) { // release builds: do not decode garbage, just free the socket
+            socket_busy.clear();
+            return;
+        }
+        rep_data.shrink(size);
+        on_reply(rep_data);
+    });
+}
+
+// Decodes the reply, hands it to the user callback (on the worker
+// thread), and only then allows the next request.
+void zmq_client::impl::on_reply(const data_type &data) {
+    auto rep = decode_reply(data);
+    reply_cb(rep);
+    socket_busy.clear();
+}
+
+// Called from the owner's thread. The busy flag is claimed atomically
+// here; testing it here but setting it later on the worker thread left a
+// window in which two requests could both be posted. Requests arriving
+// while one is in flight are dropped.
+void zmq_client::impl::on_request(const request_type &req) {
+    if (socket_busy.test_and_set()) return;
+    post(*aux_ctx, [req = req, this]() { aux_on_request(req); });
+}
+
+// Public facade: every call is forwarded to the hidden implementation.
+zmq_client::zmq_client(const create_config &conf) {
+    pimpl = std::make_unique<impl>(conf);
+}
+
+zmq_client::~zmq_client() = default;
+
+void zmq_client::request(const request_type &req) {
+    auto &self = *pimpl;
+    self.on_request(req);
+}
+
+// True when no request is currently in flight.
+bool zmq_client::is_ideal() const {
+    const bool busy = pimpl->socket_busy.test();
+    return !busy;
+}

+ 73 - 0
src/ai/impl/zmq_client_impl.h

@@ -0,0 +1,73 @@
+#ifndef DEPTHGUIDE_ZMQ_CLIENT_IMPL_H
+#define DEPTHGUIDE_ZMQ_CLIENT_IMPL_H
+
+#include "ai/zmq_client.h"
+#include "network/binary_utility.hpp"
+
+#include <azmq/socket.hpp>
+
+#include <atomic>
+#include <thread>
+
+namespace zmq_client_impl { // helpers shared by the zmq_client implementation
+
+//    extern zmq::context_t zmq_ctx;
+
+    enum image_transfer_mode { // NOTE(review): the visible encode/decode code only uses the "direct" string tag - confirm whether this enum is still needed
+        DIRECT = 0,
+        HOST_SHARED = 1,
+        CUDA_SHARED = 2
+    };
+
+    const char *cv_type_to_dtype(int type); // OpenCV depth -> dtype tag such as "|u1"
+
+    int dtype_to_cv_type(const std::string &dtype, int channels); // dtype tag + channel count -> OpenCV type
+
+    /* Head Size [sizeof(size_t) bytes]
+     * Head Json [n bytes]
+     * Img#1 Size [sizeof(size_t) bytes]
+     * Img#1 Data [n bytes]
+     * Img#n ... */
+    data_type encode_request(const zmq_client::request_type &req,
+                             smart_cuda_stream *stream);
+
+    zmq_client::reply_type decode_reply(const data_type &data); // inverse of the layout above, for replies
+
+}
+
+using namespace zmq_client_impl;
+
+// Private state of zmq_client. All socket work happens on a dedicated
+// worker thread that runs aux_ctx; the owner only posts requests to it.
+struct zmq_client::impl {
+
+    // worker thread, started by the constructor and joined by the destructor
+    std::unique_ptr<std::thread> aux_thread;
+//    image_transfer_mode trans_mode = DIRECT;
+
+    // for aux thread
+    using io_context = boost::asio::io_context;
+    std::unique_ptr<io_context> aux_ctx;
+
+    // created and destroyed on the worker thread
+    std::unique_ptr<azmq::req_socket> socket;
+    // Set while a request/reply round trip is in flight.
+    // Value-initialized: std::atomic_flag has no standard constructor taking
+    // a bool, so "= false" only compiled where the library adds one as an
+    // extension. Since C++20 a default-constructed flag is clear.
+    std::atomic_flag socket_busy{};
+
+    smart_cuda_stream aux_stream; // stream used while encoding requests
+    reply_cb_type reply_cb; // user callback, invoked on the worker thread
+
+    explicit impl(const create_config &conf);
+
+    ~impl();
+
+    // entry point for the owner's thread
+    void on_request(const request_type &req);
+
+    // for auxiliary thread
+
+    void aux_thread_work(const create_config &conf);
+
+    void aux_on_request(const request_type &req);
+
+    void try_recv_rep();
+
+    void on_reply(const data_type &data);
+
+};
+
+#endif //DEPTHGUIDE_ZMQ_CLIENT_IMPL_H

+ 54 - 0
src/ai/zmq_client.h

@@ -0,0 +1,54 @@
+#ifndef DEPTHGUIDE_ZMQ_CLIENT_H
+#define DEPTHGUIDE_ZMQ_CLIENT_H
+
+#include "core/cuda_helper.hpp"
+#include "core/image_utility_v2.h"
+
+#include <boost/container/static_vector.hpp>
+
+#include <nlohmann/json.hpp>
+
+#include <memory>
+
+class zmq_client { // asynchronous request/reply client; implementation lives behind pimpl
+public:
+
+    static constexpr auto max_img_cnt = 2; // capacity of img_list_type
+    using img_list_type =
+            boost::container::static_vector<image_ptr, max_img_cnt>;
+
+    struct request_type {
+        img_list_type img_list; // images to send; must not be empty
+        std::string model_name; // written to the "model" field of the request header
+        nlohmann::json extra; // free-form payload, written to the "extra" field
+    };
+
+    struct reply_type {
+        img_list_type img_list; // images decoded from the reply
+        nlohmann::json extra; // "extra" field of the reply header
+    };
+
+    // reply_cb will be called in another thread
+    using reply_cb_type = std::function<void(reply_type)>;
+
+    struct create_config {
+        std::string serv_addr; // like "ipc://fast_sam_v1"
+        reply_cb_type reply_cb;
+        CUcontext *cuda_ctx = nullptr; // made current on the client's worker thread
+    };
+
+    explicit zmq_client(const create_config &conf);
+
+    ~zmq_client();
+
+    // a request is dropped if the previous request has not finished yet.
+    void request(const request_type &req);
+
+    bool is_ideal() const; // true when no request is in flight (the name presumably means "idle")
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+#endif //DEPTHGUIDE_ZMQ_CLIENT_H

+ 1 - 0
src/core/image_utility.hpp

@@ -361,6 +361,7 @@ private:
     image_info_type<T> host_info;
     image_info_type<T> cuda_info;
 
+//    [[deprecated("Use REC_CREATE() and SYNC_CREATE() instead.")]]
     cudaEvent_t event = nullptr;
 
     friend class generic_image;

+ 28 - 2
src/core/image_utility_v2.h

@@ -17,6 +17,17 @@ enum pixel_format_enum {
     PIX_NV12,
 };
 
+enum meta_key_enum { // keys for per-image metadata
+    META_COLOR_FMT, // value is a color_format
+};
+
+enum color_format : uint8_t { // how an image's pixels should be interpreted as color
+    COLOR_BW, // pixel values are only 0 and 1
+    COLOR_RGB,
+    COLOR_RGBA,
+    COLOR_NV12
+};
+
 class generic_image;
 
 struct image_memory {
@@ -30,6 +41,11 @@ struct image_memory {
     void *at(int row = 0, int col = 0, int component = 0);
 
     void modified(smart_cuda_stream *stream = nullptr);
+
+    bool is_continuous() const { // true when width equals pitch, i.e. rows carry no padding
+        return width == pitch;
+    }
+
 };
 
 // collection of a same image in multiple devices
@@ -50,6 +66,9 @@ public:
     static pointer create(cv::Size size, int type,
                           pixel_format_enum pixel = PIX_NORMAL);
 
+    template<typename T>
+    static pointer create(const std::shared_ptr<smart_image<T>> &img);
+
     size_t width() const {
         return size().width;
     }
@@ -80,6 +99,8 @@ public:
     template<typename T>
     image_type_v2<T> cuda(smart_cuda_stream *stream = nullptr);
 
+    image_mem_info memory_v1(smart_cuda_stream *stream = nullptr) const;
+
     image_memory memory(memory_location loc,
                         smart_cuda_stream *stream = nullptr);
 
@@ -110,9 +131,14 @@ public:
 
 using image_ptr = generic_image::pointer;
 
-image_ptr create_image(cv::Size size, int type,
-                       pixel_format_enum pixel = PIX_NORMAL) {
+inline image_ptr create_image(cv::Size size, int type,
+                              pixel_format_enum pixel = PIX_NORMAL) {
     return generic_image::create(size, type, pixel);
 }
 
+template<typename T>
+inline image_ptr create_image(const std::shared_ptr<smart_image<T>> &img) { // shorthand for generic_image::create(img)
+    return generic_image::create(img);
+}
+
 #endif //DEPTHGUIDE_IMAGE_UTILITY_V2_H

+ 72 - 20
src/core/impl/image_utility_v2.cpp

@@ -144,6 +144,26 @@ void generic_image::impl::create_cuda(smart_cuda_stream *stream) {
     }
 }
 
+image_mem_info generic_image::impl::get_memory_v1(smart_cuda_stream *stream) const { // describes whichever storage already exists, without creating one
+    auto ret = image_mem_info{
+            .width = width_in_bytes(),
+            .height = (size_t) size.height
+    };
+    if (store_cuda.ptr != nullptr) { // CUDA storage is preferred when both exist
+        ret.loc = MEM_CUDA;
+        SYNC_CREATE(store_cuda.ptr, stream);
+        ret.ptr = store_cuda.ptr;
+        ret.pitch = store_cuda.pitch;
+    } else {
+        assert(store_host.ptr != nullptr); // at least one storage must exist
+        ret.loc = MEM_HOST;
+        SYNC_CREATE(store_host.ptr, stream);
+        ret.ptr = store_host.ptr;
+        ret.pitch = store_host.pitch;
+    }
+    return ret;
+}
+
 image_memory generic_image::impl::get_memory(memory_location loc,
                                              smart_cuda_stream *stream) {
     auto ret = image_memory();
@@ -188,7 +208,7 @@ image_type_v2<T> generic_image::impl::get_image_type_v2(smart_cuda_stream *strea
     assert(size.width <= std::numeric_limits<ushort>::max());
     assert(size.height <= std::numeric_limits<ushort>::max());
     return image_type_v2<T>(
-            store_cuda.ptr.get(), size.width, size.height, store_cuda.pitch);
+            (T *) store_cuda.ptr.get(), size.width, size.height, store_cuda.pitch);
 }
 
 template<typename T>
@@ -198,19 +218,19 @@ std::shared_ptr<smart_image<T>> generic_image::impl::get_image_v1() const {
     using info_type = image_info_type<T>;
     if (store_host.ptr != nullptr) {
         auto host_info = info_type{
-                .ptr = store_host.ptr, .loc = MEM_HOST,
-                .size = size, .pitch = store_host.pitch,
+                .ptr = std::reinterpret_pointer_cast<T>(store_host.ptr),
+                .loc = MEM_HOST, .size = size, .pitch = store_host.pitch,
         };
-        ret = std::make_shared<ret_type>(host_info);
+        ret = std::make_shared<smart_image<T>>(host_info);
     }
 
     if (store_cuda.ptr != nullptr) {
         auto cuda_info = info_type{
-                .ptr = store_cuda.ptr, .loc= MEM_CUDA,
-                .size = size, .pitch = store_cuda.pitch,
+                .ptr = std::reinterpret_pointer_cast<T>(store_cuda.ptr),
+                .loc= MEM_CUDA, .size = size, .pitch = store_cuda.pitch,
         };
         if (ret == nullptr) {
-            ret = std::make_shared<ret_type>(cuda_info);
+            ret = std::make_shared<smart_image<T>>(cuda_info);
         } else {
             ret->cuda_info = cuda_info;
         }
@@ -220,6 +240,18 @@ std::shared_ptr<smart_image<T>> generic_image::impl::get_image_v1() const {
     return ret;
 }
 
+template<typename T>
+void generic_image::impl::create_from_v1(const std::shared_ptr<smart_image<T>> &img) { // takes over pointer and pitch of each storage the v1 image has
+    if (img->host_info.ptr != nullptr) { // host side, if present
+        store_host.ptr = img->host_info.ptr;
+        store_host.pitch = img->host_info.pitch;
+    }
+    if (img->cuda_info.ptr != nullptr) { // CUDA side, if present
+        store_cuda.ptr = img->cuda_info.ptr;
+        store_cuda.pitch = img->cuda_info.pitch;
+    }
+}
+
 void generic_image::impl::sub_image_inplace(int row, int col, int width, int height) {
     // sub-image of other formats are not implemented
     assert(pix_fmt == PIX_NORMAL);
@@ -275,6 +307,13 @@ generic_image::pointer generic_image::create(cv::Size size, int type, pixel_form
     return create(conf);
 }
 
+template<typename T>
+generic_image::pointer generic_image::create(const std::shared_ptr<smart_image<T>> &img) { // builds a generic image of matching size/type from a v1 image
+    auto ret = create(img->size(), get_cv_type<T>());
+    ret->pimpl->create_from_v1(img); // then take over the v1 image's storage pointers
+    return ret;
+}
+
 cv::Size generic_image::size() const {
     return pimpl->display_size();
 }
@@ -295,6 +334,10 @@ pixel_format_enum generic_image::pixel_format() const {
     return pimpl->pix_fmt;
 }
 
+image_mem_info generic_image::memory_v1(smart_cuda_stream *stream) const { // forwards to the implementation
+    return pimpl->get_memory_v1(stream);
+}
+
 image_memory generic_image::memory(memory_location loc,
                                    smart_cuda_stream *stream) {
     return pimpl->get_memory(loc, stream);
@@ -314,12 +357,12 @@ image_type_v2<T> generic_image::cuda(smart_cuda_stream *stream) {
 }
 
 // @formatter:off
-template<> image_type_v2<uchar1> generic_image::cuda(smart_cuda_stream *stream);
-template<> image_type_v2<uchar2> generic_image::cuda(smart_cuda_stream *stream);
-template<> image_type_v2<uchar3> generic_image::cuda(smart_cuda_stream *stream);
-template<> image_type_v2<uchar4> generic_image::cuda(smart_cuda_stream *stream);
-template<> image_type_v2<ushort1> generic_image::cuda(smart_cuda_stream *stream);
-template<> image_type_v2<float1> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<uchar1> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<uchar2> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<uchar3> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<uchar4> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<ushort1> generic_image::cuda(smart_cuda_stream *stream);
+template image_type_v2<float1> generic_image::cuda(smart_cuda_stream *stream);
 // @formatter:on
 
 template<typename T>
@@ -328,12 +371,12 @@ std::shared_ptr<smart_image<T>> generic_image::v1() const {
 }
 
 // @formatter:off
-template<> std::shared_ptr<smart_image<uchar1>> generic_image::v1() const;
-template<> std::shared_ptr<smart_image<uchar2>> generic_image::v1() const;
-template<> std::shared_ptr<smart_image<uchar3>> generic_image::v1() const;
-template<> std::shared_ptr<smart_image<uchar4>> generic_image::v1() const;
-template<> std::shared_ptr<smart_image<ushort1>> generic_image::v1() const;
-template<> std::shared_ptr<smart_image<float1>> generic_image::v1() const;
+template std::shared_ptr<smart_image<uchar1>> generic_image::v1() const;
+template std::shared_ptr<smart_image<uchar2>> generic_image::v1() const;
+template std::shared_ptr<smart_image<uchar3>> generic_image::v1() const;
+template std::shared_ptr<smart_image<uchar4>> generic_image::v1() const;
+template std::shared_ptr<smart_image<ushort1>> generic_image::v1() const;
+template std::shared_ptr<smart_image<float1>> generic_image::v1() const;
 // @formatter:on
 
 generic_image::pointer generic_image::shallow_clone() const {
@@ -360,4 +403,13 @@ void generic_image::host_modified(smart_cuda_stream *stream) {
 
 void generic_image::cuda_modified(smart_cuda_stream *stream) {
     pimpl->cuda_modified(stream);
-}
+}
+
+// @formatter:off
+template image_ptr create_image(const image_u8c1 &img);
+template image_ptr create_image(const image_u8c2 &img);
+template image_ptr create_image(const image_u8c3 &img);
+template image_ptr create_image(const image_u8c4 &img);
+template image_ptr create_image(const image_u16c1 &img);
+template image_ptr create_image(const image_f32c1 &img);
+// @formatter:on

+ 5 - 0
src/core/impl/image_utility_v2_impl.h

@@ -60,6 +60,8 @@ struct generic_image::impl : public meta_base::impl {
     // imply that stream want to use cuda
     void create_cuda(smart_cuda_stream *stream);
 
+    image_mem_info get_memory_v1(smart_cuda_stream *stream) const;
+
     image_memory get_memory(memory_location loc,
                             smart_cuda_stream *stream);
 
@@ -73,6 +75,9 @@ struct generic_image::impl : public meta_base::impl {
     template<typename T>
     std::shared_ptr<smart_image<T>> get_image_v1() const;
 
+    template<typename T>
+    void create_from_v1(const std::shared_ptr<smart_image<T>> &img);
+
     // use after copy of impl
     void sub_image_inplace(int row = 0, int col = 0,
                            int width = -1, int height = -1);

+ 2 - 0
src/core/impl/memory_pool.cpp

@@ -187,6 +187,8 @@ void memory_pool::record_create(void *ptr, smart_cuda_stream *stream) {
 void memory_pool::sync_create(void *ptr, smart_cuda_stream *stream) {
     auto event = pimpl->get_event(ptr);
     if (stream == nullptr) {
+        // cudaEventSynchronize() should be used with cudaEventBlockingSync
+        assert(cudaEventQuery(event) == cudaSuccess);
         CUDA_API_CHECK(cudaEventSynchronize(event));
     } else {
         CUDA_API_CHECK(cudaStreamWaitEvent(stream->cuda, event));

+ 1 - 1
src/core/pc_utility.h

@@ -52,7 +52,7 @@ public:
 
     static pointer create(size_t size, pc_format_enum fmt);
 
-    // use to hide invisible points
+    // used to hide invisible points
     void shrink(size_t size);
 
     size_t size() const; // number of points

+ 46 - 8
src/impl/apps/depth_guide/depth_guide.cpp

@@ -14,6 +14,7 @@ app_depth_guide::app_depth_guide(const create_config &_conf) {
     auto fake_info = fake_color_config{.mode = FAKE_800P, .lower = 200, .upper = 1000};
     OBJ_SAVE(img_depth_fake_info, versatile_convertor_impl::encode_config(fake_info));
     OBJ_SAVE(img_out, image_u8c4());
+    OBJ_SAVE(img_obj_mask, image_ptr());
     OBJ_SAVE(pc_raw, pc_ptr());
 
     // initialize modules
@@ -29,6 +30,12 @@ app_depth_guide::app_depth_guide(const create_config &_conf) {
     };
     depth_encode = std::make_unique<versatile_convertor>(fake_conf);
 
+    auto ai_conf = fast_sam::create_config{
+            .img_in = img_color, .mask_out = img_obj_mask,
+            .cuda_ctx = conf.cuda_ctx, .ctx = conf.asio_ctx,
+    };
+    ai_seg = std::make_unique<fast_sam>(ai_conf);
+
     auto out_conf = stereo_augment_helper::create_config{
             .left_name = img_color, .right_name = img_depth_fake, .out_name = img_out,
             .stream = default_cuda_stream
@@ -37,12 +44,14 @@ app_depth_guide::app_depth_guide(const create_config &_conf) {
     out_combiner->fix_ui_config({.follow_image_size=true, .enable_halve_width=false});
 
     auto bg_viewer_conf = image_viewer::create_config{
-            .mode = VIEW_COLOR_DEPTH, .flip_y = true,
+            .mode = VIEW_STEREO, .flip_y = true,
             .stream = default_cuda_stream,
     };
-    auto &bg_extra_conf = bg_viewer_conf.extra.color_depth;
-    bg_extra_conf.c_name = img_color;
-    bg_extra_conf.d_name = img_depth;
+    auto &bg_extra_conf = bg_viewer_conf.extra.stereo;
+//    bg_extra_conf.c_name = img_color;
+    bg_extra_conf.left_name = img_color;
+    bg_extra_conf.c_fmt = COLOR_RGB;
+    bg_extra_conf.right_name = img_obj_mask;
     bg_viewer = std::make_unique<image_viewer>(bg_viewer_conf);
 
     auto out_streamer_conf = image_streamer::create_config{
@@ -53,6 +62,11 @@ app_depth_guide::app_depth_guide(const create_config &_conf) {
     };
     out_streamer = std::make_unique<image_streamer>(out_streamer_conf);
 
+    auto saver_conf = image_saver::create_config{.ctx = conf.asio_ctx};
+    saver_conf.img_list.emplace_back("Image", img_color);
+    saver_conf.img_list.emplace_back("Image Out", img_out);
+    out_saver = std::make_unique<image_saver>(saver_conf);
+
     auto pc_conf = pc_viewer::create_config{
             .name = pc_raw, .stream = default_cuda_stream,
     };
@@ -69,6 +83,11 @@ void app_depth_guide::show_ui() {
             orb_cam->show();
         }
 
+        if (ImGui::CollapsingHeader("AI Segmentation")) {
+            auto id_guard = imgui_id_guard("ai_segment");
+            ai_seg->show();
+        }
+
         if (ImGui::CollapsingHeader("Streamer")) {
             auto id_guard = imgui_id_guard("streamer");
             out_streamer->show();
@@ -76,8 +95,24 @@ void app_depth_guide::show_ui() {
 
         if (ImGui::CollapsingHeader("Debug")) {
             if (ImGui::TreeNode("Background")) {
-//                bg_viewer->show();
-                scene_viewer->show();
+                ImGui::RadioButton("Image", &enable_scene_viewer, false);
+                ImGui::SameLine();
+                ImGui::RadioButton("3D", &enable_scene_viewer, true);
+                ImGui::TreePop();
+            }
+            if (enable_scene_viewer) {
+                if (ImGui::TreeNode("3D Viewer")) {
+                    scene_viewer->show();
+                    ImGui::TreePop();
+                }
+            } else {
+                if (ImGui::TreeNode("Image Viewer")) {
+                    bg_viewer->show();
+                    ImGui::TreePop();
+                }
+            }
+            if (ImGui::TreeNode("Image Saver")) {
+                out_saver->show();
                 ImGui::TreePop();
             }
             if (ImGui::TreeNode("Memory Pool")) {
@@ -99,6 +134,9 @@ void app_depth_guide::show_ui() {
 }
 
 void app_depth_guide::render_background() {
-//    bg_viewer->render();
-    scene_viewer->render();
+    if (enable_scene_viewer) {
+        scene_viewer->render();
+    } else {
+        bg_viewer->render();
+    }
 }

+ 8 - 0
src/impl/apps/depth_guide/depth_guide.h

@@ -1,11 +1,13 @@
 #ifndef DEPTHGUIDE_DEPTH_GUIDE_H
 #define DEPTHGUIDE_DEPTH_GUIDE_H
 
+#include "ai/fast_sam.h"
 #include "image_process/versatile_convertor.h"
 #include "core/event_timer.h"
 #include "core/object_manager.h"
 #include "device/orb_camera_ui.h"
 #include "module/image_augment_helper.h"
+#include "module/image_saver.h"
 #include "module/image_streamer.h"
 #include "module/image_viewer.h"
 #include "module/pc_viewer.h"
@@ -41,9 +43,13 @@ private:
 
         // point cloud from device
         pc_raw,
+
+        // ai segment result
+        img_obj_mask
     };
 
     create_config conf;
+    int enable_scene_viewer = false;
 
     // modules
     std::unique_ptr<orb_camera_ui> orb_cam;
@@ -51,7 +57,9 @@ private:
     std::unique_ptr<versatile_convertor> depth_encode;
     std::unique_ptr<stereo_augment_helper> out_combiner;
     std::unique_ptr<image_streamer> out_streamer; // output streamer
+    std::unique_ptr<image_saver> out_saver;
     std::unique_ptr<pc_viewer> scene_viewer;
+    std::unique_ptr<fast_sam> ai_seg;
 
     // miscellaneous
     event_timer perf_timer; // performance timer

+ 32 - 6
src/module/impl/image_viewer.cpp

@@ -1,5 +1,6 @@
 #include "image_viewer_impl.h"
 #include "core/imgui_utility.hpp"
+#include "core/image_utility_v2.h"
 #include "render/render_texture.h"
 
 void image_viewer::impl::show_depth_only() {
@@ -14,6 +15,12 @@ void image_viewer::impl::show_depth_only() {
     ImGui::Checkbox("##manual_dep_range", &depth_conf.manual_depth_range);
 }
 
+void image_viewer::impl::show_overlay_alpha() { // slider for overlay_alpha in [0, 1], shared by the "both" modes
+    ImGui::PushItemWidth(150);
+    ImGui::SliderFloat("Overlay Alpha", &overlay_alpha, 0.f, 1.f, "%.2f");
+    ImGui::PopItemWidth();
+}
+
 void image_viewer::impl::show_color_depth() {
     ImGui::RadioButton("Color", &chose_index, 0);
     ImGui::SameLine();
@@ -24,10 +31,8 @@ void image_viewer::impl::show_color_depth() {
     if (chose_index == 1 || chose_index == 2) { // depth or both
         show_depth_only();
     }
-
-    ImGui::PushItemWidth(150);
     if (chose_index == 2) { // both
-        ImGui::SliderFloat("Depth Alpha", &depth_overlay_alpha, 0.f, 1.f, "%.2f");
+        show_overlay_alpha();
     }
 }
 
@@ -35,6 +40,12 @@ void image_viewer::impl::show_stereo() {
     ImGui::RadioButton("Left", &chose_index, 0);
     ImGui::SameLine();
     ImGui::RadioButton("Right", &chose_index, 1);
+    ImGui::SameLine();
+    ImGui::RadioButton("Both", &chose_index, 2);
+
+    if (chose_index == 2) { // both
+        show_overlay_alpha();
+    }
 }
 
 void image_viewer::impl::show_custom() {
@@ -76,8 +87,18 @@ void image_viewer::impl::show() {
     ImGui::PopItemWidth();
 }
 
-void image_viewer::impl::render_color_obj(obj_name_type name) {
-    color_render.render(name, color_conf);
+// Renders the color object `name` with the requested alpha.
+// A local copy of color_conf is adjusted, so the stored configuration stays untouched.
+void image_viewer::impl::render_color_obj(obj_name_type name, float alpha) {
+    auto ren_conf = color_conf;
+
+    ren_conf.alpha = alpha;
+    if (OBJ_TYPE(name) == typeid(image_ptr)) { // TODO: ugly hacked
+        auto img = OBJ_QUERY(image_ptr, name);
+        // When the image carries a META_COLOR_FMT entry, it overrides the configured format.
+        auto fmt = img->get_meta(META_COLOR_FMT);
+        if (fmt.has_value()) {
+            ren_conf.fmt = (color_format) *fmt;
+        }
+    }
+    color_render.render(name, ren_conf);
 }
 
 void image_viewer::impl::render_depth_obj(obj_name_type name, float alpha) {
@@ -101,7 +122,7 @@ void image_viewer::impl::render_color_depth() {
         }
         case 2: { // both
             render_color_obj(info.c_name);
-            render_depth_obj(info.d_name, depth_overlay_alpha);
+            render_depth_obj(info.d_name, overlay_alpha);
             break;
         }
         default: {
@@ -122,6 +143,11 @@ void image_viewer::impl::render_stereo() {
             render_color_obj(info.right_name);
             break;
         }
+        case 2: {
+            render_color_obj(info.left_name);
+            render_color_obj(info.right_name, overlay_alpha);
+            break;
+        }
         default: {
             RET_ERROR;
         }

+ 5 - 3
src/module/impl/image_viewer_impl.h

@@ -11,7 +11,7 @@ struct image_viewer::impl {
     /* for VIEW_COLOR_DEPTH
      *   0 = color, 1 = depth, 2 = both
      * for VIEW_STEREO
-     *   0 = left, 1 = right */
+     *   0 = left, 1 = right, 2 = both*/
     int chose_index = 0;
 
     using color_conf_type = color_image_render::config_type;
@@ -23,7 +23,9 @@ struct image_viewer::impl {
     color_conf_type color_conf = {};
     color_image_render color_render;
 
-    float depth_overlay_alpha = 0.5;
+    float overlay_alpha = 0.5;
+
+    void show_overlay_alpha();
 
     void show_depth_only();
 
@@ -35,7 +37,7 @@ struct image_viewer::impl {
 
     void show();
 
-    void render_color_obj(obj_name_type name);
+    void render_color_obj(obj_name_type name, float alpha = 1.0);
 
     // render depth with false color
     void render_depth_obj(obj_name_type name, float alpha = 1.0);

+ 29 - 5
src/network/binary_utility.hpp

@@ -116,6 +116,11 @@ struct data_type {
         *this = next;
     }
 
+    // Lowers the recorded size to `_size`; the new size must not exceed the current one (asserted).
+    void shrink(size_t _size) {
+        assert(size >= _size);
+        size = _size;
+    }
+
     template<typename T>
     T *at(size_t pos) {
         static_assert(std::is_trivial_v<T>);
@@ -233,10 +238,21 @@ public:
         return ret;
     }
 
+    // Copies the next `size` bytes at the read cursor into `data`, then advances the cursor.
+    void read_data(void *data, size_t size) {
+        // Validate the range before touching memory, so a debug build stops ahead of an over-read.
+        assert(cur_ptr + size <= end_ptr());
+        std::copy_n(cur_ptr, size, static_cast<uint8_t *>(data));
+        cur_ptr += size;
+    }
+
     void read_data(const data_type &out) {
-        std::copy_n(cur_ptr, out.size, out.ptr);
-        cur_ptr += out.size;
+        // Fill the caller's buffer: the destination and length come from `out`,
+        // matching the destination used by the removed lines above.
+        read_data(out.ptr, out.size);
+    }
+
+    // Returns the next `size` bytes at the read cursor as a std::string and advances the cursor.
+    std::string read_std_string(size_t size) {
+        // Validate the range before building the string, so a debug build stops ahead of an over-read.
+        assert(cur_ptr + size <= end_ptr());
+        auto ret = std::string(reinterpret_cast<const char *>(cur_ptr), size);
+        cur_ptr += size;
+        return ret;
+    }
 
     data_type read_remain() {
@@ -277,12 +293,20 @@ public:
         }
     }
 
-    void write_data(const data_type &_data) {
-        std::copy_n(_data.start_ptr(), _data.size, cur_ptr);
-        cur_ptr += _data.size;
+    // Copies `size` bytes from `data` to the write cursor, then advances the cursor.
+    void write_data(const void *data, size_t size) {
+        std::copy_n((uint8_t *) data, size, cur_ptr);
+        cur_ptr += size;
+        // NOTE(review): this bounds check runs after the copy, so it only reports an overrun post hoc.
         assert(cur_ptr <= end_ptr());
     }
 
+    // Writes the bytes described by `_data` (start_ptr() / size) through the raw-pointer overload.
+    void write_data(const data_type &_data) {
+        write_data(_data.start_ptr(), _data.size);
+    }
+
+    // Writes only the characters of `str`; no length prefix or terminator is emitted here.
+    void write_value(const std::string &str) {
+        write_data(str.data(), str.size());
+    }
+
     void write_value(const data_type &_data) {
+        // data_type values are written as-is by forwarding to write_data().
         write_data(_data);
     }

+ 58 - 26
src/render/impl/render_texture.cpp

@@ -13,6 +13,7 @@ namespace render_texture_impl {
     bool init_ok = false;
 
     using pg_type = std::unique_ptr<smart_program>;
+    pg_type pg_bw; // render black/white texture
     pg_type pg_rgb; // render rgb texture
     pg_type pg_rgba; // render rgba texture
     pg_type pg_rgb_d; // render rgb and depth texture
@@ -74,22 +75,46 @@ namespace render_texture_impl {
         glDrawElements(GL_TRIANGLES, 6, GL_UNSIGNED_INT, nullptr);
     }
 
+    // Renders a black/white texture with the "tex_bw" program.
+    // The program is created the first time this function runs and cached in pg_bw.
+    void ren_bw_only(const tex_render_info &info) {
+        auto &pg = pg_bw;
+        if (!pg) {
+            pg = pg_type(smart_program::create(
+                    "tex_bw",
+                    {{GL_VERTEX_SHADER,   "tex.vert"},
+                     {GL_FRAGMENT_SHADER, "tex_bw.frag"}}));
+        }
+        assert(pg != nullptr);
+        pg->use();
+
+        const auto &color = info.color;
+        pg->set_uniform_f("alpha", color.alpha);
+
+        // bind the texture to unit 0 and point the sampler at it
+        glActiveTexture(GL_TEXTURE0);
+        glBindTexture(GL_TEXTURE_2D, color.id);
+        pg->set_uniform_i("c_tex", 0);
+
+        glDisable(GL_DEPTH_TEST);
+        config_buffers(info);
+        draw();
+    }
+
     // render rgb texture
     void ren_rgb_only(const tex_render_info &info) {
-        if (pg_rgb == nullptr) {
-            pg_rgb = std::unique_ptr<smart_program>(
+        auto &pg = pg_rgb;
+        if (pg == nullptr) {
+            pg = std::unique_ptr<smart_program>(
                     smart_program::create("tex_rgb",
                                           {{GL_VERTEX_SHADER,   "tex.vert"},
                                            {GL_FRAGMENT_SHADER, "tex_rgb.frag"}}));
         }
-        assert(pg_rgb != nullptr);
-        pg_rgb->use();
+        assert(pg != nullptr);
+        pg->use();
 
-        pg_rgb->set_uniform_f("alpha", info.color.alpha);
+        pg->set_uniform_f("alpha", info.color.alpha);
 
         glActiveTexture(GL_TEXTURE0 + 0);
         glBindTexture(GL_TEXTURE_2D, info.color.id);
-        pg_rgb->set_uniform_i("c_tex", 0);
+        pg->set_uniform_i("c_tex", 0);
 
         glDisable(GL_DEPTH_TEST);
         config_buffers(info);
@@ -97,20 +122,21 @@ namespace render_texture_impl {
     }
 
     void ren_rgba_only(const tex_render_info &info) {
-        if (pg_rgba == nullptr) {
-            pg_rgba = std::unique_ptr<smart_program>(
+        auto &pg = pg_rgba;
+        if (pg == nullptr) {
+            pg = std::unique_ptr<smart_program>(
                     smart_program::create("tex_rgba",
                                           {{GL_VERTEX_SHADER,   "tex.vert"},
                                            {GL_FRAGMENT_SHADER, "tex_rgba.frag"}}));
         }
-        assert(pg_rgba != nullptr);
-        pg_rgba->use();
+        assert(pg != nullptr);
+        pg->use();
 
-        pg_rgba->set_uniform_f("opacity", info.color.opacity);
+        pg->set_uniform_f("opacity", info.color.opacity);
 
         glActiveTexture(GL_TEXTURE0 + 0);
         glBindTexture(GL_TEXTURE_2D, info.color.id);
-        pg_rgba->set_uniform_i("c_tex", 0);
+        pg->set_uniform_i("c_tex", 0);
 
         glDisable(GL_DEPTH_TEST);
         config_buffers(info);
@@ -118,23 +144,24 @@ namespace render_texture_impl {
     }
 
     void ren_rgb_d(const tex_render_info &info) {
-        if (pg_rgb_d == nullptr) {
-            pg_rgb_d = std::unique_ptr<smart_program>(
+        auto &pg = pg_rgb_d;
+        if (pg == nullptr) {
+            pg = std::unique_ptr<smart_program>(
                     smart_program::create("tex_rgb_d",
                                           {{GL_VERTEX_SHADER,   "tex.vert"},
                                            {GL_FRAGMENT_SHADER, "tex_rgb_d.frag"}}));
         }
-        assert(pg_rgb_d != nullptr);
-        pg_rgb_d->use();
+        assert(pg != nullptr);
+        pg->use();
 
-        pg_rgb_d->set_uniform_f("alpha", info.color.alpha);
+        pg->set_uniform_f("alpha", info.color.alpha);
 
         glActiveTexture(GL_TEXTURE0 + 0);
         glBindTexture(GL_TEXTURE_2D, info.color.id);
-        pg_rgb_d->set_uniform_i("c_tex", 0);
+        pg->set_uniform_i("c_tex", 0);
         glActiveTexture(GL_TEXTURE0 + 1);
         glBindTexture(GL_TEXTURE_2D, info.depth.id);
-        pg_rgb_d->set_uniform_i("d_tex", 1);
+        pg->set_uniform_i("d_tex", 1);
 
         glEnable(GL_DEPTH_TEST);
         config_buffers(info);
@@ -142,23 +169,24 @@ namespace render_texture_impl {
     }
 
     void ren_nv12_only(const tex_render_info &info) {
-        if (pg_nv12 == nullptr) {
-            pg_nv12 = std::unique_ptr<smart_program>(
+        auto &pg = pg_nv12;
+        if (pg == nullptr) {
+            pg = std::unique_ptr<smart_program>(
                     smart_program::create("tex_nv12",
                                           {{GL_VERTEX_SHADER,   "tex.vert"},
                                            {GL_FRAGMENT_SHADER, "tex_nv12.frag"}}));
         }
-        assert(pg_nv12 != nullptr);
-        pg_nv12->use();
+        assert(pg != nullptr);
+        pg->use();
 
-        pg_nv12->set_uniform_f("alpha", info.color.alpha);
+        pg->set_uniform_f("alpha", info.color.alpha);
 
         glActiveTexture(GL_TEXTURE0 + 0);
         glBindTexture(GL_TEXTURE_2D, info.color.id);
-        pg_nv12->set_uniform_i("luma_tex", 0);
+        pg->set_uniform_i("luma_tex", 0);
         glActiveTexture(GL_TEXTURE0 + 1);
         glBindTexture(GL_TEXTURE_2D, info.color.id_ext[0]);
-        pg_nv12->set_uniform_i("chroma_tex", 1);
+        pg->set_uniform_i("chroma_tex", 1);
 
         glDisable(GL_DEPTH_TEST);
         config_buffers(info);
@@ -167,6 +195,10 @@ namespace render_texture_impl {
 
     void ren_c_only(const tex_render_info &info) {
         switch (info.color.fmt) {
+            case COLOR_BW: {
+                ren_bw_only(info);
+                break;
+            }
             case COLOR_RGB: {
                 ren_rgb_only(info);
                 break;

+ 2 - 0
src/render/impl/render_texturer_impl.h

@@ -5,6 +5,8 @@
 
 namespace render_texture_impl {
 
+    void ren_bw_only(const tex_render_info &info);
+
     void ren_rgb_only(const tex_render_info &info);
 
     void ren_rgba_only(const tex_render_info &info);

+ 10 - 0
src/render/impl/render_tools.cpp

@@ -1,3 +1,4 @@
+#include "core/image_utility_v2.h"
 #include "render/render_texture.h"
 #include "render/render_tools.h"
 #include "render/render_pc.h"
@@ -42,6 +43,14 @@ template void color_image_render::render_rgba<uchar4>(const image_info_u8c4 &, c
 void color_image_render::render_rgba(obj_name_type name, config_type conf) {
     auto img_type = OBJ_TYPE(name);
 
+    // TODO: ugly hacked
+    if (img_type == typeid(image_ptr)) {
+        auto img = OBJ_QUERY(image_ptr, name);
+        img_tex.upload(img, conf.stream);
+        render_tex(img->size(), conf);
+        return;
+    }
+
     auto impl_func = [&](auto V) {
         using T = std::remove_cvref_t<decltype(V)>;
         if (img_type == typeid(T)) {
@@ -80,6 +89,7 @@ void color_image_render::render_nv12(obj_name_type name, config_type conf) {
 
 void color_image_render::render(obj_name_type name, config_type conf) {
     switch (conf.fmt) {
+        case COLOR_BW:
         case COLOR_RGB:
         case COLOR_RGBA: {
             render_rgba(name, conf);

+ 18 - 0
src/render/impl/shader/tex_bw.frag

@@ -0,0 +1,18 @@
+#version 460
+
+uniform float alpha;
+
+uniform sampler2D c_tex;  // color texture
+
+in vec2 frag_uv;
+
+const vec3 black_color = { 0, 0, 0 };
+const vec3 white_color = { 1, 1, 1 };
+
+layout (location = 0) out vec4 frag_color;
+
+// Outputs black where the red channel of c_tex is exactly zero and white elsewhere,
+// using the `alpha` uniform as the fragment's alpha.
+void main() {
+    bool is_black = (texture(c_tex, frag_uv).r == 0);
+    frag_color = vec4(is_black ? black_color : white_color, alpha);
+}

+ 4 - 0
src/render/render_tools.h

@@ -24,6 +24,10 @@ public:
 
     void render(obj_name_type name, config_type conf);
 
+// note: use render_rgba() instead
+//    void render_bw(const image_info_u8c1 &img, config_type conf);
+//    void render_bw(obj_name_type name, config_type conf);
+
     // T = uchar{1-4}
     template<typename T>
     void render_rgba(const image_info_type<T> &img, config_type conf);

+ 60 - 5
src/render/render_utility.h

@@ -3,6 +3,7 @@
 
 #include "core/cuda_helper.hpp"
 #include "core/image_utility.hpp"
+#include "core/image_utility_v2.h"
 #include "core/pc_utility.h"
 
 #include <glad/gl.h>
@@ -34,6 +35,20 @@ constexpr inline GLenum get_tex_format() {
     RET_ERROR_E;
 }
 
+// Maps the channel count of an OpenCV type code (CV_MAT_CN) to a GL pixel format.
+// Channel counts other than 1-4 fall through to RET_ERROR_E.
+constexpr inline GLenum get_tex_format(int type) {
+    const int channels = CV_MAT_CN(type);
+    if (channels == 1) { return GL_RED; }
+    if (channels == 2) { return GL_RG; }
+    if (channels == 3) { return GL_RGB; }
+    if (channels == 4) { return GL_RGBA; }
+    RET_ERROR_E;
+}
+
 template<typename T>
 constexpr inline GLenum get_tex_type() {
     // @formatter:off
@@ -45,6 +60,17 @@ constexpr inline GLenum get_tex_type() {
     RET_ERROR_E;
 }
 
+// Maps the depth of an OpenCV type code (CV_MAT_DEPTH) to a GL component type.
+// Only CV_8U is handled; every other depth falls through to RET_ERROR_E.
+constexpr inline GLenum get_tex_type(int type) {
+    if (CV_MAT_DEPTH(type) == CV_8U) {
+        return GL_UNSIGNED_BYTE;
+    }
+    RET_ERROR_E;
+}
+
 template<typename T>
 constexpr inline GLenum get_tex_internal_format() {
     // @formatter:off
@@ -56,11 +82,19 @@ constexpr inline GLenum get_tex_internal_format() {
     RET_ERROR_E;
 }
 
-enum color_format : uint8_t {
-    COLOR_RGB,
-    COLOR_RGBA,
-    COLOR_NV12
-};
+// Maps an OpenCV type code to a sized GL internal format.
+// Only the 8-bit unsigned 1-4 channel types are handled; anything else falls through to RET_ERROR_E.
+constexpr inline GLenum get_tex_internal_format(int type) { // OpenCV type
+    if (type == CV_8UC1) { return GL_R8; }
+    if (type == CV_8UC2) { return GL_RG8; }
+    if (type == CV_8UC3) { return GL_RGB8; }
+    if (type == CV_8UC4) { return GL_RGBA8; }
+    RET_ERROR_E;
+}
 
 struct simple_rect {
     GLfloat x, y;
@@ -90,6 +124,11 @@ public:
         upload_impl(pc_mem.ptr, pc->size_in_bytes(), MEM_HOST, stream);
     }
 
+    // Uploads an image_ptr: create() is called with the image's byte size,
+    // then the image's v1 memory view is handed to upload_impl().
+    void upload(const image_ptr &img, smart_cuda_stream *stream) {
+        const auto byte_count = img->size_in_bytes();
+        create(byte_count);
+        upload_impl(img->memory_v1(stream), stream);
+    }
+
     template<typename T>
     void upload(const image_info_type<T> &img, smart_cuda_stream *stream) {
         create(sizeof(T) * img.size.area()); // TODO: create inside upload_impl
@@ -157,6 +196,11 @@ public:
         upload_impl(pbo_id, get_tex_format<T>(), get_tex_type<T>());
     }
 
+    // Uploads from the buffer `pbo_id`; the internal format, pixel format and
+    // component type are all derived from the OpenCV type code `type`.
+    void upload(GLuint pbo_id, cv::Size _size, int type) {
+        const auto internal_fmt = get_tex_internal_format(type);
+        create(internal_fmt, _size);
+
+        const auto pixel_fmt = get_tex_format(type);
+        const auto pixel_type = get_tex_type(type);
+        upload_impl(pbo_id, pixel_fmt, pixel_type);
+    }
+
     template<typename T>
     void upload(const image_info_type<T> &img, smart_cuda_stream *stream) {
         create(get_tex_internal_format<T>(), img.size);
@@ -170,6 +214,17 @@ public:
         }
     }
 
+    // Uploads an image_ptr. Three-channel images are staged through img_pbo first;
+    // every other channel count is passed straight to upload_impl().
+    void upload(const image_ptr &img, smart_cuda_stream *stream) {
+        const auto img_type = img->cv_type();
+        create(get_tex_internal_format(img_type), img->size());
+
+        if (CV_MAT_CN(img_type) != 3) {
+            upload_impl(img->memory_v1(stream), stream);
+            return;
+        }
+
+        img_pbo.upload(img, stream);
+        upload(img_pbo.id, size, img_type);
+    }
+
 private:
 
     void deallocate();