Prechádzať zdrojové kódy

Decoupled zmq_server

jcsyshc 1 rok pred
rodič
commit
296855bdbc

+ 64 - 0
src/ai/impl/python/zmq_server.py

@@ -0,0 +1,64 @@
+import json
+import struct
+import numpy as np
+import zmq
+
+
+class ZmqServer(object):
+    def __init__(self, addr):
+        ctx = zmq.Context()
+        self.addr = addr
+        self.socket = ctx.socket(zmq.REP)
+
+    @staticmethod
+    def decode_request(msg):
+        head_size = struct.unpack('>Q', msg[:8])[0]
+        msg = msg[8:]
+        head = json.loads(msg[:head_size].decode('utf-8'))
+        msg = msg[head_size:]
+        extra = head['extra']
+
+        img_list = []
+        for img_info in head['images']:
+            assert img_info['transfer'] == 'direct'
+            img_bytes = struct.unpack('>Q', msg[:8])[0]
+            msg = msg[8:]
+
+            img_dtype = img_info['dtype']
+            img_data = np.frombuffer(msg[:img_bytes], dtype=img_dtype)
+            msg = msg[img_bytes:]
+
+            img_shape = img_info['shape']
+            img = np.resize(img_data, img_shape)
+            img_list.append(img)
+
+        return img_list, extra
+
+    @staticmethod
+    def encode_reply(img_list, extra):
+        head = {'extra': extra, 'images': []}
+        for img in img_list:
+            head['images'].append({
+                'dtype': img.dtype.str,
+                'shape': img.shape,
+                'transfer': 'direct'})
+
+        head_str = json.dumps(head).encode('utf-8')
+        msg = struct.pack('>Q', len(head_str))
+        msg += head_str
+
+        for img in img_list:
+            img_bytes = img.tobytes()
+            msg += struct.pack('>Q', len(img_bytes))
+            msg += img_bytes
+
+        return msg
+
+    def start(self, func):
+        self.socket.bind(self.addr)
+        while True:
+            req = self.socket.recv()
+            imgs, opts = self.decode_request(req)
+            imgs, opts = func(imgs, opts)
+            rep = self.encode_reply(imgs, opts)
+            self.socket.send(rep)

+ 1 - 0
src/module/impl/image_viewer.cpp

@@ -93,6 +93,7 @@ void image_viewer::impl::render_color_obj(obj_name_type name, float alpha) {
     ren_conf.alpha = alpha;
     if (OBJ_TYPE(name) == typeid(image_ptr)) { // TODO: ugly hacked
         auto img = OBJ_QUERY(image_ptr, name);
+        if (img == nullptr) return;
         auto fmt = img->get_meta(META_COLOR_FMT);
         if (fmt.has_value()) {
             ren_conf.fmt = (color_format) *fmt;