forked from Cyvadra/a-julia-bert-client
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BertClient.jl
60 lines (47 loc) · 2.22 KB
/
BertClient.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
module BertClient
using JSON, ZMQ, PyCall
# NOTE(review): the first three names are parameters of `bert_init`, not module-level
# bindings; the export list is kept unchanged for backward compatibility.
export tcp_server_address, port_client_input, port_server_output, bert_encode, bert_init

# No longer used by `bert_encode` (replies are decoded with `reinterpret`); kept so that
# loading the module and `BertClient.numpy` behave as before.
# NOTE(review): PyCall recommends doing `pyimport` inside `__init__` when a module is
# precompiled; confirm how this module is loaded.
numpy = pyimport("numpy")

# Identity sent as the first frame of every request and used as the SUB filter for
# replies. ZeroMQ subscriptions match by prefix, so the random suffix is fixed-width hex:
# no identity can be a strict prefix of another, and 64 bits make collisions unlikely.
self_identity = "9636-0150-8aaa-" * string(rand(UInt64), base=16, pad=16)

"""
    s2u(s::AbstractString)

Return `s` with every character written as a JSON `\\uXXXX` escape. The result is pure
ASCII and needs no further escaping inside a JSON string literal. Characters outside the
Basic Multilingual Plane are written as UTF-16 surrogate pairs, as JSON requires.
"""
s2u(s::AbstractString) =
    join("\\u" * string(u, base=16, pad=4) for u in transcode(UInt16, String(s)))
# thanks Bogumit Kaminski for the original version of `s2u`

# NOTE(review): the sockets below are created without `self_context`, so they presumably
# use ZMQ.jl's default context; `self_context` is otherwise unused — confirm intent.
self_context = ZMQ.Context()
self_sender = Socket(ZMQ.PUSH)
self_receiver = Socket(ZMQ.SUB)

"""
    bert_init(tcp_server_address="127.0.0.1", port_client_input=5555, port_server_output=5556)

Connect the module's PUSH socket to `tcp://<tcp_server_address>:<port_client_input>`,
subscribe the SUB socket to `self_identity`, and connect it to
`tcp://<tcp_server_address>:<port_server_output>`. Then discard any remaining parts of a
partially read message on the SUB socket.
"""
function bert_init(tcp_server_address="127.0.0.1", port_client_input=5555,
                   port_server_output=5556)
    connect(self_sender, "tcp://$tcp_server_address:$port_client_input")
    subscribe(self_receiver, self_identity)
    connect(self_receiver, "tcp://$tcp_server_address:$port_server_output")
    while self_receiver.rcvmore
        ZMQ.recv(self_receiver)
    end
end

# Connect with the default address and ports as soon as the module is loaded.
bert_init()

# Receive every frame of one multipart message from `sock`, as byte vectors.
function _recv_frames(sock::Socket)
    frames = Vector{UInt8}[convert(Array, ZMQ.recv(sock))]
    while sock.rcvmore
        push!(frames, convert(Array, ZMQ.recv(sock)))
    end
    return frames
end

# Block until the reply to request `req_id` arrives and return its frames:
# (client id, header JSON, raw array bytes, request id). Replies whose last frame carries a
# different request id (e.g. a late answer to an earlier call) are dropped, matching the
# way the reference Python client pairs replies with requests.
function _recv_reply(req_id)
    while true
        frames = _recv_frames(self_receiver)
        if length(frames) < 4
            @warn "ignoring malformed reply" nframes = length(frames)
            continue
        end
        tryparse(Int, String(copy(frames[end]))) == req_id && return frames
    end
end

# Element types the server may report in the header's "dtype" field.
const _DTYPES = Dict("float16" => Float16, "float32" => Float32, "float64" => Float64)

# Decode the raw reply bytes into a flat Vector{Float32}, keeping the buffer's element
# order (the server's row-major order: all values for the first text, then the next, ...).
# Bytes are read in host byte order, as the previous numpy.frombuffer decoding did.
# A missing "dtype" is treated as float32, which the previous code always assumed.
function _decode_array(header, bytes::Vector{UInt8})
    dtype = get(header, "dtype", "float32")
    T = get(_DTYPES, dtype) do
        throw(ArgumentError("unsupported dtype in server reply: $dtype"))
    end
    vals = Vector{Float32}(reinterpret(T, bytes))
    shape = header["shape"]
    length(vals) == prod(shape) || throw(DimensionMismatch(
        "reply holds $(length(vals)) values but its header shape is $shape"))
    return vals
end

"""
    bert_encode(texts::AbstractArray, req_id = abs(rand(Int8)))

Send `texts` to the server as one encode request and return the embeddings as a flat
`Vector{Float32}`: the vector for `texts[1]`, followed by the one for `texts[2]`, and so on.

The request consists of four frames: `self_identity`, the texts as an ASCII-only JSON
array, `req_id` as decimal text, and the number of texts as decimal text. Replies that
carry a request id other than `req_id` are discarded. An empty `texts` returns `Float32[]`
without contacting the server.
"""
function bert_encode(texts::AbstractArray, req_id = abs(rand(Int8)))
    isempty(texts) && return Float32[]
    payload = "[" * join(("\"" * s2u(t) * "\"" for t in texts), ",") * "]"
    ZMQ.send(self_sender, self_identity; more=true)
    ZMQ.send(self_sender, payload; more=true)
    ZMQ.send(self_sender, string(req_id); more=true)
    ZMQ.send(self_sender, string(length(texts)); more=false)
    frames = _recv_reply(req_id)
    header = JSON.parse(String(copy(frames[2])))
    return _decode_array(header, frames[3])
end

"""
    bert_encode(text::AbstractString, req_id = abs(rand(Int8)))

Encode a single `text`; equivalent to `bert_encode([text], req_id)`.
"""
function bert_encode(text::AbstractString, req_id = abs(rand(Int8)))
    return bert_encode([text], req_id)
end
end