Skip to content

Commit

Permalink
Merge pull request #157 from evonzee/opensb-support-norestart
Browse files Browse the repository at this point in the history
OpenSB Support - register filters always, use flags to control compression loop
  • Loading branch information
rubellyte authored Oct 25, 2024
2 parents a61e384 + 89bf87b commit 6600b0a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
30 changes: 12 additions & 18 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class StarryPyServer:
"""
def __init__(self, reader, writer, config, factory):
logger.debug("Initializing connection.")
self._reader = reader # read packets from client
self._writer = writer # writes packets to client
self._reader = ZstdFrameReader(reader, Direction.TO_SERVER) # read packets from client
self._writer = ZstdFrameWriter(writer) # writes packets to client
self._client_reader = None # read packets from server (acting as client)
self._client_writer = None # write packets to server
self.factory = factory
Expand All @@ -48,17 +48,13 @@ def __init__(self, reader, writer, config, factory):
self._client_read_future = None
self._server_write_future = None
self._client_write_future = None
self._expect_server_loop_death = False
logger.info("Received connection from {}".format(self.client_ip))

def start_zstd(self):
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
self._client_writer = ZstdFrameWriter(self._client_writer)
self._expect_server_loop_death = True
self._server_loop_future.cancel()
self._server_loop_future = asyncio.create_task(self.server_loop())
self._reader.enable_zstd()
self._client_reader.enable_zstd()
self._writer.enable_zstd(skip_packets=1) # skip this packet
self._client_writer.enable_zstd()
logger.info("Switched to zstd")


Expand Down Expand Up @@ -95,12 +91,8 @@ async def server_loop(self):
"{}: {}".format(err.__class__.__name__, err))
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
finally:
if not self._expect_server_loop_death:
logger.info("Server loop ended.")
self.die()
else:
logger.info("Restarting server loop for switch to zstd.")
self._expect_server_loop_death = False
logger.info("Server loop ended.")
self.die()

async def client_loop(self):
"""
Expand All @@ -109,9 +101,11 @@ async def client_loop(self):
:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
(reader, writer) = await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])

self._client_reader = ZstdFrameReader(reader, Direction.TO_CLIENT)
self._client_writer = ZstdFrameWriter(writer)

try:
while True:
Expand Down
17 changes: 12 additions & 5 deletions zstd_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ def __init__(self, reader: asyncio.StreamReader, direction: Direction):
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
self.raw_reader = reader
self.direction = direction
self.zstd_enabled = False

def enable_zstd(self):
self.zstd_enabled = True

async def readexactly(self, count):
# print(f"Reading exactly {count} bytes")
Expand All @@ -31,11 +35,14 @@ async def read_from_network(self, target_count):
# print(f"Read {len(chunk)} bytes from network")
if not chunk:
raise asyncio.CancelledError("Connection closed")
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")
if not self.zstd_enabled:
self.outputbuffer.write(chunk)
else:
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")

class NonSeekableMemoryStream(io.RawIOBase):
def __init__(self):
Expand Down
11 changes: 10 additions & 1 deletion zstd_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
import zstandard as zstd

class ZstdFrameWriter:
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
def __init__(self, raw_writer: asyncio.StreamWriter):
self.compressor = zstd.ZstdCompressor()
self.raw_writer = raw_writer
self.skip_packets = 0
self.zstd_enabled = False

def enable_zstd(self, skip_packets=0):
self.zstd_enabled = True
self.skip_packets = skip_packets

async def drain(self):
Expand All @@ -16,6 +21,10 @@ def close(self):
self.compressor = None

def write(self, data):

if not self.zstd_enabled:
self.raw_writer.write(data)
return

if self.skip_packets > 0:
self.skip_packets -= 1
Expand Down

0 comments on commit 6600b0a

Please sign in to comment.