|
|
@@ -0,0 +1,64 @@
|
|
|
+import json
|
|
|
+import struct
|
|
|
+import numpy as np
|
|
|
+import zmq
|
|
|
+
|
|
|
+
|
|
|
+class ZmqServer(object):
|
|
|
+ def __init__(self, addr):
|
|
|
+ ctx = zmq.Context()
|
|
|
+ self.addr = addr
|
|
|
+ self.socket = ctx.socket(zmq.REP)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def decode_request(msg):
|
|
|
+ head_size = struct.unpack('>Q', msg[:8])[0]
|
|
|
+ msg = msg[8:]
|
|
|
+ head = json.loads(msg[:head_size].decode('utf-8'))
|
|
|
+ msg = msg[head_size:]
|
|
|
+ extra = head['extra']
|
|
|
+
|
|
|
+ img_list = []
|
|
|
+ for img_info in head['images']:
|
|
|
+ assert img_info['transfer'] == 'direct'
|
|
|
+ img_bytes = struct.unpack('>Q', msg[:8])[0]
|
|
|
+ msg = msg[8:]
|
|
|
+
|
|
|
+ img_dtype = img_info['dtype']
|
|
|
+ img_data = np.frombuffer(msg[:img_bytes], dtype=img_dtype)
|
|
|
+ msg = msg[img_bytes:]
|
|
|
+
|
|
|
+ img_shape = img_info['shape']
|
|
|
+ img = np.resize(img_data, img_shape)
|
|
|
+ img_list.append(img)
|
|
|
+
|
|
|
+ return img_list, extra
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def encode_reply(img_list, extra):
|
|
|
+ head = {'extra': extra, 'images': []}
|
|
|
+ for img in img_list:
|
|
|
+ head['images'].append({
|
|
|
+ 'dtype': img.dtype.str,
|
|
|
+ 'shape': img.shape,
|
|
|
+ 'transfer': 'direct'})
|
|
|
+
|
|
|
+ head_str = json.dumps(head).encode('utf-8')
|
|
|
+ msg = struct.pack('>Q', len(head_str))
|
|
|
+ msg += head_str
|
|
|
+
|
|
|
+ for img in img_list:
|
|
|
+ img_bytes = img.tobytes()
|
|
|
+ msg += struct.pack('>Q', len(img_bytes))
|
|
|
+ msg += img_bytes
|
|
|
+
|
|
|
+ return msg
|
|
|
+
|
|
|
+ def start(self, func):
|
|
|
+ self.socket.bind(self.addr)
|
|
|
+ while True:
|
|
|
+ req = self.socket.recv()
|
|
|
+ imgs, opts = self.decode_request(req)
|
|
|
+ imgs, opts = func(imgs, opts)
|
|
|
+ rep = self.encode_reply(imgs, opts)
|
|
|
+ self.socket.send(rep)
|