From c1be487bb00f2e813212031d93fcebbfbd0da60b Mon Sep 17 00:00:00 2001 From: Robbie Harwood Date: Thu, 29 Aug 2019 11:13:41 -0400 Subject: [PATCH] Always buffer TCP data in __handle_recv() Refactor __handle_recv() to always create a BytesIO() object for TCP data. Linearize control flow for ease of debugging. Always apply length checks so that we don't have to wait for EOF in the multiple-recv case. Fixes a bug where we wouldn't return any data because we never received the EOF, or didn't receive it fast enough. Signed-off-by: Robbie Harwood (cherry picked from commit 7e2b1ab27b843c220fe301b74bab01ed61b0f59a) --- kdcproxy/__init__.py | 54 +++++++++++++++++++++++++------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/kdcproxy/__init__.py b/kdcproxy/__init__.py index 6526bc9..9bc7044 100644 --- a/kdcproxy/__init__.py +++ b/kdcproxy/__init__.py @@ -128,29 +128,37 @@ class Application: # length prefix. So add it. reply = struct.pack("!I", len(reply)) + reply return reply - else: - # TCP is a different story. The reply must be buffered - # until the full answer is accumulated. - buf = read_buffers.get(sock) - part = sock.recv(1048576) - if buf is None: - if len(part) > 4: - # got enough data in the initial package. Now check - # if we got the full package in the first run. - (length, ) = struct.unpack("!I", part[0:4]) - if length + 4 == len(part): - return part - read_buffers[sock] = buf = io.BytesIO() - - if part: - # data received, accumulate it in a buffer - buf.write(part) - return None - else: - # EOF received - read_buffers.pop(sock) - reply = buf.getvalue() - return reply + + # TCP is a different story. The reply must be buffered until the full + # answer is accumulated. + buf = read_buffers.get(sock) + if buf is None: + read_buffers[sock] = buf = io.BytesIO() + + part = sock.recv(1048576) + if not part: + # EOF received. Return any incomplete data we have on the theory + # that a decode error is more apparent than silent failure. The + # client will fail faster, at least. + read_buffers.pop(sock) + reply = buf.getvalue() + return reply + + # Data received, accumulate it in a buffer. + buf.write(part) + + reply = buf.getvalue() + if len(reply) < 4: + # We don't have the length yet. + return None + + # Got enough data to check if we have the full package. + (length, ) = struct.unpack("!I", reply[0:4]) + if length + 4 == len(reply): + read_buffers.pop(sock) + return reply + + return None def __filter_addr(self, addr): if addr[0] not in (socket.AF_INET, socket.AF_INET6):