diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..11583f9 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include LICENSE +graft tests diff --git a/README.rst b/README.rst index 5f543d1..5b6a455 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,7 @@ python-proxy ============ -|made-with-python| |PyPI-version| |Hit-Count| |Downloads| +|made-with-python| |PyPI-version| |Hit-Count| |Downloads| |Downloads-month| |Downloads-week| .. |made-with-python| image:: https://img.shields.io/badge/Made%20with-Python-1f425f.svg :target: https://www.python.org/ @@ -11,8 +11,12 @@ python-proxy :target: https://pypi.python.org/pypi/pproxy/ .. |Downloads| image:: https://pepy.tech/badge/pproxy :target: https://pepy.tech/project/pproxy +.. |Downloads-month| image:: https://pepy.tech/badge/pproxy/month + :target: https://pepy.tech/project/pproxy +.. |Downloads-week| image:: https://pepy.tech/badge/pproxy/week + :target: https://pepy.tech/project/pproxy -HTTP/Socks4/Socks5/Shadowsocks/ShadowsocksR/SSH/Redirect/Pf TCP/UDP asynchronous tunnel proxy implemented in Python3 asyncio. +HTTP/HTTP2/HTTP3/Socks4/Socks5/Shadowsocks/SSH/Redirect/Pf/QUIC TCP/UDP asynchronous tunnel proxy implemented in Python3 asyncio. QuickStart ---------- @@ -73,8 +77,9 @@ Features - Proxy client/server for TCP/UDP. - Schedule (load balance) among remote servers. - Incoming traffic auto-detect. -- Tunnel/relay/backward-relay support. +- Tunnel/jump/backward-jump support. - Unix domain socket support. +- HTTP v2, HTTP v3 (QUIC) - User/password authentication support. - Filter/block hostname by regex patterns. - SSL/TLS client/server support. @@ -98,6 +103,10 @@ Protocols | http | | ✔ | | | httponly:// | | (get,post,etc) | | | | | (as client) | +-------------------+------------+------------+------------+------------+--------------+ +| http v2 (connect) | ✔ | ✔ | | | h2:// | ++-------------------+------------+------------+------------+------------+--------------+ +| http v3 (connect) | ✔ by UDP | ✔ by UDP | | | h3:// | ++-------------------+------------+------------+------------+------------+--------------+ | https | ✔ | ✔ | | | http+ssl:// | +-------------------+------------+------------+------------+------------+--------------+ | socks4 | ✔ | ✔ | | | socks4:// | @@ -112,8 +121,12 @@ Protocols +-------------------+------------+------------+------------+------------+--------------+ | shadowsocksR | ✔ | ✔ | | | ssr:// | +-------------------+------------+------------+------------+------------+--------------+ +| trojan | ✔ | ✔ | | | trojan:// | ++-------------------+------------+------------+------------+------------+--------------+ | ssh tunnel | | ✔ | | | ssh:// | +-------------------+------------+------------+------------+------------+--------------+ +| quic | ✔ by UDP | ✔ by UDP | ✔ | ✔ | http+quic:// | ++-------------------+------------+------------+------------+------------+--------------+ | iptables nat | ✔ | | | | redir:// | +-------------------+------------+------------+------------+------------+--------------+ | pfctl nat (macos) | ✔ | | | | pf:// | @@ -230,6 +243,8 @@ URI Syntax +----------+-----------------------------+ | ssr | shadowsocksr (SSR) protocol | +----------+-----------------------------+ + | trojan | trojan_ protocol | + +----------+-----------------------------+ | ssh | ssh client tunnel | +----------+-----------------------------+ | redir | redirect (iptables nat) | @@ -249,6 +264,8 @@ URI Syntax | direct | direct connection | +----------+-----------------------------+ +.. 
_trojan: https://trojan-gfw.github.io/trojan/protocol + - "http://" accepts GET/POST/CONNECT as server, sends CONNECT as client. "httponly://" sends "GET/POST" as client, works only on http traffic. - Valid schemes: http://, http+socks4+socks5://, http+ssl://, ss+secure://, http+socks5+ss:// @@ -360,7 +377,7 @@ URI Syntax - The username, colon ':', and the password -URIs can be joined by "__" to indicate tunneling by relay. For example, ss://1.2.3.4:1324__http://4.5.6.7:4321 make remote connection to the first shadowsocks proxy server, and then tunnel to the second http proxy server. +URIs can be joined by "__" to indicate tunneling by jump. For example, ss://1.2.3.4:1324__http://4.5.6.7:4321 makes a remote connection to the first shadowsocks proxy server, and then jumps to the second http proxy server. .. _AEAD: http://shadowsocks.org/en/spec/AEAD-Ciphers.html @@ -553,9 +570,7 @@ Examples Make sure **pproxy** runs in root mode (sudo), otherwise it cannot redirect pf packet. -- Relay tunnel - - Relay tunnel example: +- Multiple jumps example .. code:: rst @@ -653,6 +668,12 @@ Examples Server connects to client_ip:8081 and waits for client proxy requests. The protocol http specified is just an example. It can be any protocol and cipher **pproxy** supports. The scheme "**in**" should exist in URI to inform **pproxy** that it is a backward proxy. + .. code:: rst + + $ pproxy -l http+in://jumpserver__http://client_ip:8081 + + This is a more complex example: the server connects to client_ip:8081 through the jump http://jumpserver, so the backward proxy works across jumps. + - SSH client tunnel SSH client tunnel support is enabled by installing additional library asyncssh_. After "pip3 install asyncssh", you can specify "**ssh**" as scheme to proxy via ssh client tunnel. @@ -669,10 +690,73 @@ Examples SSH connection known_hosts feature is disabled by default. +- SSH jump + + SSH jumps are supported by "__" concatenation: + + .. code:: rst + + $ pproxy -r ssh://server1__ssh://server2__ssh://server3 + + First, a connection to server1 is made. Second, an SSH connection to server2 is made from server1. Finally, server3 is reached from server2 and used to proxy the traffic. + +- SSH remote forward + + .. code:: rst + + $ pproxy -l ssh://server__tunnel://0.0.0.0:1234 -r tunnel://127.0.0.1:1234 + + TCP port 1234 on the remote server is forwarded to 127.0.0.1:1234 on the local host. + + .. code:: rst + + $ pproxy -l ssh://server1__ssh://server2__ss://0.0.0.0:1234 -r ss://server3:1234 + + This is a more complex example: SSH server2 is reached by jumping through SSH server1, ss://0.0.0.0:1234 is listened on server2, and the traffic is forwarded to ss://server3:1234. + +- Trojan protocol example + + Normally trojan:// should be used together with ssl://. You should specify the SSL crt/key files for ssl usage. A typical trojan server would be: + + .. code:: rst + + $ pproxy --ssl ssl.crt,ssl.key -l trojan+tunnel{localhost:80}+ssl://:443#yourpassword -vv + + If the trojan password doesn't match, traffic is switched to tunnel{localhost:80}, so the proxy looks exactly the same as a common HTTPS website. + +- QUIC protocol example + + QUIC is a UDP-based stream protocol used in HTTP/3. The library **aioquic** is required if you want to proxy via QUIC. + QUIC listens on a UDP port, but it can carry either TCP or UDP traffic. If you want to handle TCP traffic, you should use "-l quic+http" instead of "-ul quic+http". + + .. 
code:: rst + + $ pip3 install aioquic + $ pproxy --ssl ssl.crt,ssl.key -l quic+http://:1234 + + On the client: + + $ pproxy -r quic+http://server:1234 + + The QUIC protocol can carry many TCP streams over one single UDP connection. If the number of connections is huge, QUIC can help by reducing TCP handshake overhead. + +- VPN Server Example + + You can run a VPN server simply by installing pvpn (python vpn), a lightweight VPN server with the pproxy tunnel feature. + + .. code:: rst + + $ pip3 install pvpn + Successfully installed pvpn-0.2.1 + $ pvpn -wg 9999 -r http://remote_server:remote_port + Serving on UDP :500 :4500... + Serving on UDP :9000 (WIREGUARD)... + TCP xx.xx.xx.xx:xx -> HTTP xx.xx.xx.xx:xx -> xx.xx.xx.xx:xx + Projects -------- -+ `python-esp `_ - Pure python VPN (IPSec,IKE,IKEv2,L2TP) -+ `shadowproxy `_ - Another awesome proxy implementation by guyingbo ++ `python-vpn `_ - VPN Server (IPSec,IKE,IKEv2,L2TP,WireGuard) in pure python ++ `shadowproxy `_ - Awesome python proxy implementation by guyingbo diff --git a/pproxy/__doc__.py b/pproxy/__doc__.py index 4a230fd..075c5b0 100644 --- a/pproxy/__doc__.py +++ b/pproxy/__doc__.py @@ -1,5 +1,4 @@ __title__ = "pproxy" -__version__ = "2.3.7" __license__ = "MIT" __description__ = "Proxy server that can tunnel among remote servers by regex rules." __keywords__ = "proxy socks http shadowsocks shadowsocksr ssr redirect pf tunnel cipher ssl udp" @@ -7,4 +6,14 @@ __email__ = "qianwenjie@gmail.com" __url__ = "https://github.com/qwj/python-proxy" +try: + from setuptools_scm import get_version + __version__ = get_version() +except Exception: + try: + from pkg_resources import get_distribution + __version__ = get_distribution('pproxy').version + except Exception: + __version__ = 'unknown' + __all__ = ['__version__', '__description__', '__url__'] diff --git a/pproxy/__init__.py b/pproxy/__init__.py index ce08ce9..0450bfa 100644 --- a/pproxy/__init__.py +++ b/pproxy/__init__.py @@ -1,6 +1,6 @@ from . import server -Connection = server.ProxyURI.compile_relay -DIRECT = server.ProxyURI.DIRECT -Server = server.ProxyURI.compile -Rule = server.ProxyURI.compile_rule +Connection = server.proxies_by_uri +Server = server.proxies_by_uri +Rule = server.compile_rule +DIRECT = server.DIRECT diff --git a/pproxy/proto.py b/pproxy/proto.py index b0f7632..8b0e328 100644 --- a/pproxy/proto.py +++ b/pproxy/proto.py @@ -3,6 +3,16 @@ HTTP_LINE = re.compile('([^ ]+) +(.+?) 
+(HTTP/[^ ]+)$') packstr = lambda s, n=1: len(s).to_bytes(n, 'big') + s +def netloc_split(loc, default_host=None, default_port=None): + ipv6 = re.fullmatch('\[([0-9a-fA-F:]*)\](?::(\d+)?)?', loc) + if ipv6: + host_name, port = ipv6.groups() + elif ':' in loc: + host_name, port = loc.rsplit(':', 1) + else: + host_name, port = loc, None + return host_name or default_host, int(port) if port else default_port + async def socks_address_stream(reader, n): if n in (1, 17): data = await reader.read_n(4) @@ -33,21 +43,21 @@ def name(self): return self.__class__.__name__.lower() def reuse(self): return False - def udp_parse(self, data, **kw): + def udp_accept(self, data, **kw): raise Exception(f'{self.name} don\'t support UDP server') def udp_connect(self, rauth, host_name, port, data, **kw): raise Exception(f'{self.name} don\'t support UDP client') - def udp_client(self, data): + def udp_unpack(self, data): return data - def udp_client2(self, host_name, port, data): + def udp_pack(self, host_name, port, data): return data async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): raise Exception(f'{self.name} don\'t support client') async def channel(self, reader, writer, stat_bytes, stat_conn): try: stat_conn(1) - while True: - data = await reader.read_() + while not reader.at_eof() and not writer.is_closing(): + data = await reader.read(65536) if not data: break if stat_bytes is None: @@ -64,23 +74,48 @@ async def channel(self, reader, writer, stat_bytes, stat_conn): class Direct(BaseProtocol): pass +class Trojan(BaseProtocol): + async def guess(self, reader, users, **kw): + header = await reader.read_w(56) + if users: + for user in users: + if hashlib.sha224(user).hexdigest().encode() == header: + return user + else: + if hashlib.sha224(b'').hexdigest().encode() == header: + return True + reader.rollback(header) + async def accept(self, reader, user, **kw): + assert await reader.read_n(2) == b'\x0d\x0a' + if (await reader.read_n(1))[0] != 1: + raise Exception('Connection closed') + host_name, port, _ = await socks_address_stream(reader, (await reader.read_n(1))[0]) + assert await reader.read_n(2) == b'\x0d\x0a' + return user, host_name, port + async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): + toauth = hashlib.sha224(rauth or b'').hexdigest().encode() + writer_remote.write(toauth + b'\x0d\x0a\x01\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big') + b'\x0d\x0a') + class SSR(BaseProtocol): - def correct_header(self, header, auth, **kw): - return auth and header == auth[:1] or not auth and header and header[0] in (1, 3, 4) - async def parse(self, header, reader, auth, authtable, **kw): - if auth: - if (await reader.read_n(len(auth)-1)) != auth[1:]: - raise Exception('Unauthorized SSR') - authtable.set_authed() - header = await reader.read_n(1) - host_name, port, data = await socks_address_stream(reader, header[0]) - return host_name, port + async def guess(self, reader, users, **kw): + if users: + header = await reader.read_w(max(len(i) for i in users)) + reader.rollback(header) + user = next(filter(lambda x: x == header[:len(x)], users), None) + if user is None: + return + await reader.read_n(len(user)) + return user + header = await reader.read_w(1) + reader.rollback(header) + return header[0] in (1, 3, 4, 17, 19, 20) + async def accept(self, reader, user, **kw): + host_name, port, data = await socks_address_stream(reader, (await reader.read_n(1))[0]) + return user, host_name, port async def connect(self, reader_remote, 
writer_remote, rauth, host_name, port, **kw): writer_remote.write(rauth + b'\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big')) -class SS(BaseProtocol): - def correct_header(self, header, auth, **kw): - return auth and header == auth[:1] or not auth and header and header[0] in (1, 3, 4, 17, 19, 20) +class SS(SSR): def patch_ota_reader(self, cipher, reader): chunk_id, data_len, _buffer = 0, None, bytearray() def decrypt(s): @@ -115,12 +150,8 @@ def write(data, o=writer.write): chunk_id += 1 return o(len(data).to_bytes(2, 'big') + checksum[:10] + data) writer.write = write - async def parse(self, header, reader, auth, authtable, reader_cipher, **kw): - if auth: - if (await reader.read_n(len(auth)-1)) != auth[1:]: - raise Exception('Unauthorized SS') - authtable.set_authed() - header = await reader.read_n(1) + async def accept(self, reader, user, reader_cipher, **kw): + header = await reader.read_n(1) ota = (header[0] & 0x10 == 0x10) host_name, port, data = await socks_address_stream(reader, header[0]) assert ota or not reader_cipher or not reader_cipher.ota, 'SS client must support OTA' @@ -128,7 +159,7 @@ async def parse(self, header, reader, auth, authtable, reader_cipher, **kw): checksum = hmac.new(reader_cipher.iv+reader_cipher.key, header+data, hashlib.sha1).digest() assert checksum[:10] == await reader.read_n(10), 'Unknown OTA checksum' self.patch_ota_reader(reader_cipher, reader) - return host_name, port + return user, host_name, port async def connect(self, reader_remote, writer_remote, rauth, host_name, port, writer_cipher_r, **kw): writer_remote.write(rauth) if writer_cipher_r and writer_cipher_r.ota: @@ -138,21 +169,25 @@ async def connect(self, reader_remote, writer_remote, rauth, host_name, port, wr self.patch_ota_writer(writer_cipher_r, writer_remote) else: writer_remote.write(b'\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big')) - def udp_parse(self, data, auth, **kw): + def udp_accept(self, data, users, **kw): reader = io.BytesIO(data) - if auth and reader.read(len(auth)) != auth: - return + user = True + if users: + user = next(filter(lambda i: data[:len(i)]==i, users), None) + if user is None: + return + reader.read(len(user)) n = reader.read(1)[0] if n not in (1, 3, 4): return host_name, port = socks_address(reader, n) - return host_name, port, reader.read() - def udp_client(self, data): + return user, host_name, port, reader.read() + def udp_unpack(self, data): reader = io.BytesIO(data) n = reader.read(1)[0] host_name, port = socks_address(reader, n) return reader.read() - def udp_client2(self, host_name, port, data): + def udp_pack(self, host_name, port, data): try: return b'\x01' + socket.inet_aton(host_name) + port.to_bytes(2, 'big') + data except Exception: @@ -162,19 +197,25 @@ def udp_connect(self, rauth, host_name, port, data, **kw): return rauth + b'\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big') + data class Socks4(BaseProtocol): - def correct_header(self, header, **kw): - return header == b'\x04' - async def parse(self, reader, writer, auth, authtable, **kw): + async def guess(self, reader, **kw): + header = await reader.read_w(1) + if header == b'\x04': + return True + reader.rollback(header) + async def accept(self, reader, user, writer, users, authtable, **kw): assert await reader.read_n(1) == b'\x01' port = int.from_bytes(await reader.read_n(2), 'big') ip = await reader.read_n(4) userid = (await reader.read_until(b'\x00'))[:-1] - if auth: - if auth != userid and not authtable.authed(): - raise Exception(f'Unauthorized 
SOCKS {auth}') - authtable.set_authed() + user = authtable.authed() + if users: + if userid in users: + user = userid + elif not user: + raise Exception(f'Unauthorized SOCKS {userid}') + authtable.set_authed(user) writer.write(b'\x00\x5a' + port.to_bytes(2, 'big') + ip) - return socket.inet_ntoa(ip), port + return user, socket.inet_ntoa(ip), port async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): ip = socket.inet_aton((await asyncio.get_event_loop().getaddrinfo(host_name, port, family=socket.AF_INET))[0][4][0]) writer_remote.write(b'\x04\x01' + port.to_bytes(2, 'big') + ip + rauth + b'\x00') @@ -182,33 +223,50 @@ async def connect(self, reader_remote, writer_remote, rauth, host_name, port, ** await reader_remote.read_n(6) class Socks5(BaseProtocol): - def correct_header(self, header, **kw): - return header == b'\x05' - async def parse(self, reader, writer, auth, authtable, **kw): + async def guess(self, reader, **kw): + header = await reader.read_w(1) + if header == b'\x05': + return True + reader.rollback(header) + async def accept(self, reader, user, writer, users, authtable, **kw): methods = await reader.read_n((await reader.read_n(1))[0]) - if auth and (b'\x00' not in methods or not authtable.authed()): + user = authtable.authed() + if users and (not user or b'\x00' not in methods): + if b'\x02' not in methods: + raise Exception(f'Unauthorized SOCKS') writer.write(b'\x05\x02') assert (await reader.read_n(1))[0] == 1, 'Unknown SOCKS auth' u = await reader.read_n((await reader.read_n(1))[0]) p = await reader.read_n((await reader.read_n(1))[0]) - if u+b':'+p != auth: + user = u+b':'+p + if user not in users: raise Exception(f'Unauthorized SOCKS {u}:{p}') writer.write(b'\x01\x00') + elif users and not user: + raise Exception(f'Unauthorized SOCKS') else: writer.write(b'\x05\x00') - if auth: - authtable.set_authed() - assert (await reader.read_n(3)) == b'\x05\x01\x00', 'Unknown SOCKS protocol' + if users: + authtable.set_authed(user) + assert await reader.read_n(3) == b'\x05\x01\x00', 'Unknown SOCKS protocol' header = await reader.read_n(1) host_name, port, data = await socks_address_stream(reader, header[0]) writer.write(b'\x05\x00\x00' + header + data) - return host_name, port + return user, host_name, port async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): - writer_remote.write((b'\x05\x01\x02\x01' + b''.join(packstr(i) for i in rauth.split(b':', 1)) if rauth else b'\x05\x01\x00') + b'\x05\x01\x00\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big')) - await reader_remote.read_until(b'\x00\x05\x00\x00') + if rauth: + writer_remote.write(b'\x05\x01\x02') + assert await reader_remote.read_n(2) == b'\x05\x02' + writer_remote.write(b'\x01' + b''.join(packstr(i) for i in rauth.split(b':', 1))) + assert await reader_remote.read_n(2) == b'\x01\x00', 'Unknown SOCKS auth' + else: + writer_remote.write(b'\x05\x01\x00') + assert await reader_remote.read_n(2) == b'\x05\x00' + writer_remote.write(b'\x05\x01\x00\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big')) + assert await reader_remote.read_n(3) == b'\x05\x00\x00' header = (await reader_remote.read_n(1))[0] await reader_remote.read_n(6 if header == 1 else (18 if header == 4 else (await reader_remote.read_n(1))[0]+2)) - def udp_parse(self, data, **kw): + def udp_accept(self, data, **kw): reader = io.BytesIO(data) if reader.read(3) != b'\x00\x00\x00': return @@ -216,58 +274,68 @@ def udp_parse(self, data, **kw): if n not in (1, 3, 4): return host_name, port = 
socks_address(reader, n) - return host_name, port, reader.read() + return True, host_name, port, reader.read() def udp_connect(self, rauth, host_name, port, data, **kw): return b'\x00\x00\x00\x03' + packstr(host_name.encode()) + port.to_bytes(2, 'big') + data class HTTP(BaseProtocol): - def correct_header(self, header, **kw): - return header and header.isalpha() - async def parse(self, header, reader, writer, auth, authtable, httpget=None, **kw): - lines = header + await reader.read_until(b'\r\n\r\n') + async def guess(self, reader, **kw): + header = await reader.read_w(4) + reader.rollback(header) + return header in (b'GET ', b'HEAD', b'POST', b'PUT ', b'DELE', b'CONN', b'OPTI', b'TRAC', b'PATC') + async def accept(self, reader, user, writer, **kw): + lines = await reader.read_until(b'\r\n\r\n') headers = lines[:-4].decode().split('\r\n') method, path, ver = HTTP_LINE.match(headers.pop(0)).groups() lines = '\r\n'.join(i for i in headers if not i.startswith('Proxy-')) headers = dict(i.split(': ', 1) for i in headers if ': ' in i) + async def reply(code, message, body=None, wait=False): + writer.write(message) + if body: + writer.write(body) + if wait: + await writer.drain() + return await self.http_accept(user, method, path, None, ver, lines, headers.get('Host', ''), headers.get('Proxy-Authorization'), reply, **kw) + async def http_accept(self, user, method, path, authority, ver, lines, host, pauth, reply, authtable, users, httpget=None, **kw): url = urllib.parse.urlparse(path) - if method == 'GET' and not url.hostname and httpget: - for path, text in httpget.items(): - if url.path == path: - authtable.set_authed() - if type(text) is str: - text = (text % dict(host=headers["Host"])).encode() - writer.write(f'{ver} 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nCache-Control: max-age=900\r\nContent-Length: {len(text)}\r\n\r\n'.encode() + text) - await writer.drain() - raise Exception('Connection closed') + if method == 'GET' and not url.hostname: + for path, text in (httpget.items() if httpget else ()): + if path == url.path: + user = next(filter(lambda x: x.decode()==url.query, users), None) if users else True + if user: + if users: + authtable.set_authed(user) + if type(text) is str: + text = (text % dict(host=host)).encode() + await reply(200, f'{ver} 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nCache-Control: max-age=900\r\nContent-Length: {len(text)}\r\n\r\n'.encode(), text, True) + raise Exception('Connection closed') raise Exception(f'404 {method} {url.path}') - if auth: - pauth = headers.get('Proxy-Authorization', None) - httpauth = 'Basic ' + base64.b64encode(auth).decode() - if not authtable.authed() and pauth != httpauth: - writer.write(f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode()) - raise Exception('Unauthorized HTTP') - authtable.set_authed() + if users: + user = authtable.authed() + if not user: + user = next(filter(lambda i: ('Basic '+base64.b64encode(i).decode()) == pauth, users), None) + if user is None: + await reply(407, f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode(), wait=True) + raise Exception('Unauthorized HTTP') + authtable.set_authed(user) if method == 'CONNECT': - host_name, port = path.rsplit(':', 1) - port = int(port) - return host_name, port, f'{ver} 200 OK\r\nConnection: close\r\n\r\n'.encode() + host_name, port = netloc_split(authority or path) + return user, host_name, port, 
lambda writer: reply(200, f'{ver} 200 OK\r\nConnection: close\r\n\r\n'.encode()) else: - url = urllib.parse.urlparse(path) - if ':' in url.netloc: - host_name, port = url.netloc.rsplit(':', 1) - port = int(port) - else: - host_name, port = url.netloc, 80 + host_name, port = netloc_split(url.netloc or host, default_port=80) newpath = url._replace(netloc='', scheme='').geturl() - return host_name, port, b'', f'{method} {newpath} {ver}\r\n{lines}\r\n\r\n'.encode() + async def connected(writer): + writer.write(f'{method} {newpath} {ver}\r\n{lines}\r\n\r\n'.encode()) + return True + return user, host_name, port, connected async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): writer_remote.write(f'CONNECT {host_name}:{port} HTTP/1.1\r\nHost: {myhost}'.encode() + (b'\r\nProxy-Authorization: Basic '+base64.b64encode(rauth) if rauth else b'') + b'\r\n\r\n') await reader_remote.read_until(b'\r\n\r\n') async def http_channel(self, reader, writer, stat_bytes, stat_conn): try: stat_conn(1) - while True: - data = await reader.read_() + while not reader.at_eof() and not writer.is_closing(): + data = await reader.read(65536) if not data: break if b'\r\n' in data and HTTP_LINE.match(data.split(b'\r\n', 1)[0].decode()): @@ -310,29 +378,46 @@ def write(data, o=writer_remote.write): return o(data) writer_remote.write = write +class H2(HTTP): + async def guess(self, reader, **kw): + return True + async def accept(self, reader, user, writer, **kw): + if not writer.headers.done(): + await writer.headers + headers = writer.headers.result() + headers = {i.decode().lower():j.decode() for i,j in headers} + lines = '\r\n'.join(i for i in headers if not i.startswith('proxy-') and not i.startswith(':')) + async def reply(code, message, body=None, wait=False): + writer.send_headers(((':status', str(code)),)) + if body: + writer.write(body) + if wait: + await writer.drain() + return await self.http_accept(user, headers[':method'], headers[':path'], headers[':authority'], '2.0', lines, '', headers.get('proxy-authorization'), reply, **kw) + async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): + headers = [(':method', 'CONNECT'), (':scheme', 'https'), (':path', '/'), + (':authority', f'{host_name}:{port}')] + if rauth: + headers.append(('proxy-authorization', 'Basic '+base64.b64encode(rauth))) + writer_remote.send_headers(headers) + +class H3(H2): + pass + class SSH(BaseProtocol): async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): pass class Transparent(BaseProtocol): - def correct_header(self, header, auth, sock, **kw): + async def guess(self, reader, sock, **kw): remote = self.query_remote(sock) - if remote is None or sock.getsockname() == remote: - return False - return auth and header == auth[:1] or not auth - async def parse(self, reader, auth, authtable, sock, **kw): - if auth: - if (await reader.read_n(len(auth)-1)) != auth[1:]: - raise Exception(f'Unauthorized {self.name}') - authtable.set_authed() + return remote is not None and (sock is None or sock.getsockname() != remote) + async def accept(self, reader, user, sock, **kw): remote = self.query_remote(sock) - return remote[0], remote[1] - def udp_parse(self, data, auth, sock, **kw): - reader = io.BytesIO(data) - if auth and reader.read(len(auth)) != auth: - return + return user, remote[0], remote[1] + def udp_accept(self, data, sock, **kw): remote = self.query_remote(sock) - return remote[0], remote[1], reader.read() + return True, remote[0], 
remote[1], data SO_ORIGINAL_DST = 80 SOL_IPV6 = 41 @@ -371,19 +456,18 @@ class Tunnel(Transparent): def query_remote(self, sock): if not self.param: return 'tunnel', 0 - host, _, port = self.param.partition(':') - dst = sock.getsockname() - host = host or dst[0] - port = int(port) if port else dst[1] - return host, port + dst = sock.getsockname() if sock else (None, None) + return netloc_split(self.param, dst[0], dst[1]) async def connect(self, reader_remote, writer_remote, rauth, host_name, port, **kw): - writer_remote.write(rauth) + pass def udp_connect(self, rauth, host_name, port, data, **kw): - return rauth + data + return data class WS(BaseProtocol): - def correct_header(self, header, **kw): - return header and header.isalpha() + async def guess(self, reader, **kw): + header = await reader.read_w(4) + reader.rollback(header) + return reader == b'GET ' def patch_ws_stream(self, reader, writer, masked=False): data_len, mask_key, _buffer = None, None, bytearray() def feed_data(s, o=reader.feed_data): @@ -424,20 +508,22 @@ def write(data, o=writer.write): else: return o(b'\x02' + (bytes([data_len]) if data_len < 126 else b'\x7e'+data_len.to_bytes(2, 'big') if data_len < 65536 else b'\x7f'+data_len.to_bytes(4, 'big')) + data) writer.write = write - async def parse(self, header, reader, writer, auth, authtable, sock, **kw): - lines = header + await reader.read_until(b'\r\n\r\n') + async def accept(self, reader, user, writer, users, authtable, sock, **kw): + lines = await reader.read_until(b'\r\n\r\n') headers = lines[:-4].decode().split('\r\n') method, path, ver = HTTP_LINE.match(headers.pop(0)).groups() lines = '\r\n'.join(i for i in headers if not i.startswith('Proxy-')) headers = dict(i.split(': ', 1) for i in headers if ': ' in i) url = urllib.parse.urlparse(path) - if auth: + if users: pauth = headers.get('Proxy-Authorization', None) - httpauth = 'Basic ' + base64.b64encode(auth).decode() - if not authtable.authed() and pauth != httpauth: - writer.write(f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode()) - raise Exception('Unauthorized WebSocket') - authtable.set_authed() + user = authtable.authed() + if not user: + user = next(filter(lambda i: ('Basic '+base64.b64encode(i).decode()) == pauth, users), None) + if user is None: + writer.write(f'{ver} 407 Proxy Authentication Required\r\nConnection: close\r\nProxy-Authenticate: Basic realm="simple"\r\n\r\n'.encode()) + raise Exception('Unauthorized WebSocket') + authtable.set_authed(user) if method != 'GET': raise Exception(f'Unsupported method {method}') if headers.get('Sec-WebSocket-Key', None) is None: @@ -448,11 +534,9 @@ async def parse(self, header, reader, writer, auth, authtable, sock, **kw): self.patch_ws_stream(reader, writer, False) if not self.param: return 'tunnel', 0 - host, _, port = self.param.partition(':') dst = sock.getsockname() - host = host or dst[0] - port = int(port) if port else dst[1] - return host, port + host, port = netloc_split(self.param, dst[0], dst[1]) + return user, host, port async def connect(self, reader_remote, writer_remote, rauth, host_name, port, myhost, **kw): seckey = base64.b64encode(os.urandom(16)).decode() writer_remote.write(f'GET / HTTP/1.1\r\nHost: {myhost}\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {seckey}\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Version: 13'.encode() + (b'\r\nProxy-Authorization: Basic '+base64.b64encode(rauth) if rauth else b'') + b'\r\n\r\n') @@ -463,110 +547,27 
@@ class Echo(Transparent): def query_remote(self, sock): return 'echo', 0 -class Pack(BaseProtocol): - def reuse(self): - return True - def get_handler(self, reader, writer, verbose, tcp_handler=None, udp_handler=None): - class Handler: - def __init__(self): - self.sessions = {} - self.udpmap = {} - self.closed = False - self.ready = False - asyncio.ensure_future(self.reader_handler()) - def __bool__(self): - return not self.closed - async def reader_handler(self): - try: - while True: - try: - header = (await reader.readexactly(1))[0] - except Exception: - raise Exception('Connection closed') - sid = await reader.read_n(8) - if header in (0x01, 0x03, 0x04, 0x11, 0x13, 0x14): - host_name, port, _ = await socks_address_stream(reader, header) - if (header & 0x10 == 0) and tcp_handler: - remote_reader, remote_writer = self.get_streams(sid) - asyncio.ensure_future(tcp_handler(remote_reader, remote_writer, host_name, port)) - elif (header & 0x10 != 0) and udp_handler: - self.get_datagram(sid, host_name, port) - elif header in (0x20, 0x30): - datalen = int.from_bytes(await reader.read_n(2), 'big') - data = await reader.read_n(datalen) - if header == 0x20 and sid in self.sessions: - self.sessions[sid].feed_data(data) - elif header == 0x30 and sid in self.udpmap and udp_handler: - host_name, port, sendto = self.udpmap[sid] - asyncio.ensure_future(udp_handler(sendto, data, host_name, port, sid)) - elif header == 0x40: - if sid in self.sessions: - self.sessions.pop(sid).feed_eof() - else: - raise Exception(f'Unknown header {header}') - except Exception as ex: - if not isinstance(ex, asyncio.TimeoutError) and not str(ex).startswith('Connection closed'): - verbose(f'{str(ex) or "Unsupported protocol"}') - finally: - for sid, session in self.sessions.items(): - session.feed_eof() - try: writer.close() - except Exception: pass - self.closed = True - def get_streams(self, sid): - self.sessions[sid] = asyncio.StreamReader() - class Writer(): - def write(self, data): - while len(data) >= 32*1024: - writer.write(b'\x20'+sid+(32*1024).to_bytes(2,'big')+data[:32*1024]) - data = data[32*1024:] - if data: - writer.write(b'\x20'+sid+len(data).to_bytes(2,'big')+data) - def drain(self): - return writer.drain() - def close(self): - if not writer.transport.is_closing(): - writer.write(b'\x40'+sid) - return self.sessions[sid], Writer() - def connect(self, host_name, port): - self.ready = True - sid = os.urandom(8) - writer.write(b'\x03' + sid + packstr(host_name.encode()) + port.to_bytes(2, 'big')) - return self.get_streams(sid) - def get_datagram(self, sid, host_name, port): - def sendto(data): - if data: - writer.write(b'\x30'+sid+len(data).to_bytes(2,'big')+data) - self.udpmap[sid] = (host_name, port, sendto) - return self.udpmap[sid] - writer.get_extra_info('socket').setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - return Handler() - -async def parse(protos, reader, **kw): - proto = next(filter(lambda p: p.correct_header(None, **kw), protos), None) - if proto is None: +async def accept(protos, reader, **kw): + for proto in protos: try: - header = await reader.read_n(1) + user = await proto.guess(reader, **kw) except Exception: raise Exception('Connection closed') - proto = next(filter(lambda p: p.correct_header(header, **kw), protos), None) - else: - header = None - if proto is not None: - ret = await proto.parse(header=header, reader=reader, **kw) - while len(ret) < 4: - ret += (b'',) - return (proto,) + ret - raise Exception(f'Unsupported protocol {header}') + if user: + ret = await proto.accept(reader, 
user, **kw) + while len(ret) < 4: + ret += (None,) + return (proto,) + ret + raise Exception(f'Unsupported protocol') -def udp_parse(protos, data, **kw): +def udp_accept(protos, data, **kw): for proto in protos: - ret = proto.udp_parse(data, **kw) + ret = proto.udp_accept(data, **kw) if ret: return (proto,) + ret raise Exception(f'Unsupported protocol {data[:10]}') -MAPPINGS = dict(direct=Direct, http=HTTP, httponly=HTTPOnly, ssh=SSH, socks5=Socks5, socks4=Socks4, socks=Socks5, ss=SS, ssr=SSR, redir=Redir, pf=Pf, tunnel=Tunnel, echo=Echo, pack=Pack, ws=WS, ssl='', secure='') +MAPPINGS = dict(direct=Direct, http=HTTP, httponly=HTTPOnly, ssh=SSH, socks5=Socks5, socks4=Socks4, socks=Socks5, ss=SS, ssr=SSR, redir=Redir, pf=Pf, tunnel=Tunnel, echo=Echo, ws=WS, trojan=Trojan, h2=H2, h3=H3, ssl='', secure='', quic='') MAPPINGS['in'] = '' def get_protos(rawprotos): @@ -609,14 +610,15 @@ def close(self): def _force_close(self, exc): if not self.closed: (verbose or print)(f'{exc} from {writer.get_extra_info("peername")[0]}') + ssl._app_transport._closed = True self.close() def abort(self): self.close() ssl.connection_made(Transport()) async def channel(): try: - while True: - data = await reader.read_() + while not reader.at_eof() and not ssl._app_transport._closed: + data = await reader.read(65536) if not data: break ssl.data_received(data) @@ -632,6 +634,9 @@ def write(self, data): ssl._app_transport.write(data) def drain(self): return writer.drain() + def is_closing(self): + return ssl._app_transport._closed def close(self): - ssl._app_transport.close() + if not ssl._app_transport._closed: + ssl._app_transport.close() return ssl_reader, Writer() diff --git a/pproxy/server.py b/pproxy/server.py index 8bba08a..5b75331 100644 --- a/pproxy/server.py +++ b/pproxy/server.py @@ -2,24 +2,32 @@ from . 
import proto from .__doc__ import * -SOCKET_TIMEOUT = 300 -PACKET_SIZE = 65536 +SOCKET_TIMEOUT = 60 UDP_LIMIT = 30 DUMMY = lambda s: s -asyncio.StreamReader.read_ = lambda self: self.read(PACKET_SIZE) -asyncio.StreamReader.read_n = lambda self, n: asyncio.wait_for(self.readexactly(n), timeout=SOCKET_TIMEOUT) -asyncio.StreamReader.read_until = lambda self, s: asyncio.wait_for(self.readuntil(s), timeout=SOCKET_TIMEOUT) +def patch_StreamReader(c=asyncio.StreamReader): + c.read_w = lambda self, n: asyncio.wait_for(self.read(n), timeout=SOCKET_TIMEOUT) + c.read_n = lambda self, n: asyncio.wait_for(self.readexactly(n), timeout=SOCKET_TIMEOUT) + c.read_until = lambda self, s: asyncio.wait_for(self.readuntil(s), timeout=SOCKET_TIMEOUT) + c.rollback = lambda self, s: self._buffer.__setitem__(slice(0, 0), s) +def patch_StreamWriter(c=asyncio.StreamWriter): + c.is_closing = lambda self: self._transport.is_closing() # Python 3.6 fix +patch_StreamReader() +patch_StreamWriter() class AuthTable(object): _auth = {} + _user = {} def __init__(self, remote_ip, authtime): self.remote_ip = remote_ip self.authtime = authtime def authed(self): - return time.time() - self._auth.get(self.remote_ip, 0) <= self.authtime - def set_authed(self): + if time.time() - self._auth.get(self.remote_ip, 0) <= self.authtime: + return self._user[self.remote_ip] + def set_authed(self, user): self._auth[self.remote_ip] = time.time() + self._user[self.remote_ip] = user async def prepare_ciphers(cipher, reader, writer, bind=None, server_side=True): if cipher: @@ -35,7 +43,7 @@ async def prepare_ciphers(cipher, reader, writer, bind=None, server_side=True): return None, None def schedule(rserver, salgorithm, host_name, port): - filter_cond = lambda o: o.alive and (not o.match or o.match(host_name) or o.match(str(port))) + filter_cond = lambda o: o.alive and o.match_rule(host_name, port) if salgorithm == 'fa': return next(filter(filter_cond, rserver), None) elif salgorithm == 'rr': @@ -47,22 +55,23 @@ def schedule(rserver, salgorithm, host_name, port): filters = [i for i in rserver if filter_cond(i)] return random.choice(filters) if filters else None elif salgorithm == 'lc': - return min(filter(filter_cond, rserver), default=None, key=lambda i: i.total) + return min(filter(filter_cond, rserver), default=None, key=lambda i: i.connections) else: raise Exception('Unknown scheduling algorithm') #Unreachable -async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, sslserver, debug=0, authtime=86400*30, block=None, salgorithm='fa', verbose=DUMMY, modstat=lambda r,h:lambda i:DUMMY, **kwargs): +async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, sslserver, debug=0, authtime=86400*30, block=None, salgorithm='fa', verbose=DUMMY, modstat=lambda u,r,h:lambda i:DUMMY, **kwargs): try: reader, writer = proto.sslwrap(reader, writer, sslserver, True, None, verbose) if unix: remote_ip, server_ip, remote_text = 'local', None, 'unix_local' else: - remote_ip, remote_port, *_ = writer.get_extra_info('peername') + peername = writer.get_extra_info('peername') + remote_ip, remote_port, *_ = peername if peername else ('unknow_remote_ip','unknow_remote_port') server_ip = writer.get_extra_info('sockname')[0] remote_text = f'{remote_ip}:{remote_port}' local_addr = None if server_ip in ('127.0.0.1', '::1', None) else (server_ip, 0) reader_cipher, _ = await prepare_ciphers(cipher, reader, writer, server_side=False) - lproto, host_name, port, lbuf, rbuf = await proto.parse(protos, reader=reader, writer=writer, 
authtable=AuthTable(remote_ip, authtime), reader_cipher=reader_cipher, sock=writer.get_extra_info('socket'), **kwargs) + lproto, user, host_name, port, client_connected = await proto.accept(protos, reader=reader, writer=writer, authtable=AuthTable(remote_ip, authtime), reader_cipher=reader_cipher, sock=writer.get_extra_info('socket'), **kwargs) if host_name == 'echo': asyncio.ensure_future(lproto.channel(reader, writer, DUMMY, DUMMY)) elif host_name == 'empty': @@ -70,7 +79,7 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s elif block and block(host_name): raise Exception('BLOCK ' + host_name) else: - roption = schedule(rserver, salgorithm, host_name, port) or ProxyURI.DIRECT + roption = schedule(rserver, salgorithm, host_name, port) or DIRECT verbose(f'{lproto.name} {remote_text}{roption.logtext(host_name, port)}') try: reader_remote, writer_remote = await roption.open_connection(host_name, port, local_addr, lbind) @@ -78,13 +87,12 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s raise Exception(f'Connection timeout {roption.bind}') try: reader_remote, writer_remote = await roption.prepare_connection(reader_remote, writer_remote, host_name, port) - writer.write(lbuf) - writer_remote.write(rbuf) + use_http = (await client_connected(writer_remote)) if client_connected else None except Exception: writer_remote.close() raise Exception('Unknown remote protocol') - m = modstat(remote_ip, host_name) - lchannel = lproto.http_channel if rbuf else lproto.channel + m = modstat(user, remote_ip, host_name) + lchannel = lproto.http_channel if use_http else lproto.channel asyncio.ensure_future(lproto.channel(reader_remote, writer, m(2+roption.direct), m(4+roption.direct))) asyncio.ensure_future(lchannel(reader, writer_remote, m(roption.direct), roption.connection_change)) except Exception as ex: @@ -95,61 +103,12 @@ async def stream_handler(reader, writer, unix, lbind, protos, rserver, cipher, s if debug: raise -async def reuse_stream_handler(reader, writer, unix, lbind, protos, rserver, urserver, block, cipher, salgorithm, verbose=DUMMY, modstat=lambda r,h:lambda i:DUMMY, **kwargs): - try: - if unix: - remote_ip, server_ip, remote_text = 'local', None, 'unix_local' - else: - remote_ip, remote_port, *_ = writer.get_extra_info('peername') - server_ip = writer.get_extra_info('sockname')[0] - remote_text = f'{remote_ip}:{remote_port}' - local_addr = None if server_ip in ('127.0.0.1', '::1', None) else (server_ip, 0) - reader_cipher, _ = await prepare_ciphers(cipher, reader, writer, server_side=False) - lproto = protos[0] - except Exception as ex: - verbose(f'{str(ex) or "Unsupported protocol"} from {remote_ip}') - async def tcp_handler(reader, writer, host_name, port): - try: - if block and block(host_name): - raise Exception('BLOCK ' + host_name) - roption = schedule(rserver, salgorithm, host_name, port) or ProxyURI.DIRECT - verbose(f'{lproto.name} {remote_text}{roption.logtext(host_name, port)}') - try: - reader_remote, writer_remote = await roption.open_connection(host_name, port, local_addr, lbind) - except asyncio.TimeoutError: - raise Exception(f'Connection timeout {roption.bind}') - try: - reader_remote, writer_remote = await roption.prepare_connection(reader_remote, writer_remote, host_name, port) - except Exception: - writer_remote.close() - raise Exception('Unknown remote protocol') - m = modstat(remote_ip, host_name) - asyncio.ensure_future(lproto.channel(reader_remote, writer, m(2+roption.direct), m(4+roption.direct))) - 
asyncio.ensure_future(lproto.channel(reader, writer_remote, m(roption.direct), roption.connection_change)) - except Exception as ex: - if not isinstance(ex, asyncio.TimeoutError) and not str(ex).startswith('Connection closed'): - verbose(f'{str(ex) or "Unsupported protocol"} from {remote_ip}') - try: writer.close() - except Exception: pass - async def udp_handler(sendto, data, host_name, port, sid): - try: - if block and block(host_name): - raise Exception('BLOCK ' + host_name) - roption = schedule(urserver, salgorithm, host_name, port) or ProxyURI.DIRECT - verbose(f'UDP {lproto.name} {remote_text}{roption.logtext(host_name, port)}') - data = roption.prepare_udp_connection(host_name, port, data) - await roption.open_udp_connection(host_name, port, data, sid, sendto) - except Exception as ex: - if not str(ex).startswith('Connection closed'): - verbose(f'{str(ex) or "Unsupported protocol"} from {remote_ip}') - lproto.get_handler(reader, writer, verbose, tcp_handler, udp_handler) - async def datagram_handler(writer, data, addr, protos, urserver, block, cipher, salgorithm, verbose=DUMMY, **kwargs): try: remote_ip, remote_port, *_ = addr remote_text = f'{remote_ip}:{remote_port}' data = cipher.datagram.decrypt(data) if cipher else data - lproto, host_name, port, data = proto.udp_parse(protos, data, sock=writer.get_extra_info('socket'), **kwargs) + lproto, user, host_name, port, data = proto.udp_accept(protos, data, sock=writer.get_extra_info('socket'), **kwargs) if host_name == 'echo': writer.sendto(data, addr) elif host_name == 'empty': @@ -157,13 +116,13 @@ async def datagram_handler(writer, data, addr, protos, urserver, block, cipher, elif block and block(host_name): raise Exception('BLOCK ' + host_name) else: - roption = schedule(urserver, salgorithm, host_name, port) or ProxyURI.DIRECT + roption = schedule(urserver, salgorithm, host_name, port) or DIRECT verbose(f'UDP {lproto.name} {remote_text}{roption.logtext(host_name, port)}') - data = roption.prepare_udp_connection(host_name, port, data) + data = roption.udp_prepare_connection(host_name, port, data) def reply(rdata): - rdata = lproto.udp_client2(host_name, port, rdata) + rdata = lproto.udp_pack(host_name, port, rdata) writer.sendto(cipher.datagram.encrypt(rdata) if cipher else rdata, addr) - await roption.open_udp_connection(host_name, port, data, addr, reply) + await roption.udp_open_connection(host_name, port, data, addr, reply) except Exception as ex: if not str(ex).startswith('Connection closed'): verbose(f'{str(ex) or "Unsupported protocol"} from {remote_ip}') @@ -172,7 +131,7 @@ async def check_server_alive(interval, rserver, verbose): while True: await asyncio.sleep(interval) for remote in rserver: - if remote.direct: + if type(remote) is ProxyDirect: continue try: _, writer = await remote.open_connection(None, None, None, None, timeout=3) @@ -187,94 +146,34 @@ async def check_server_alive(interval, rserver, verbose): verbose(f'{remote.rproto.name} {remote.bind} -> ONLINE') remote.alive = True try: - if remote.backward: + if isinstance(remote, ProxyBackward): writer.write(b'\x00') writer.close() except Exception: pass -class BackwardConnection(object): - def __init__(self, uri, count): - self.uri = uri - self.count = count - self.closed = False - self.conn = asyncio.Queue() - async def open_connection(self): - while True: - reader, writer = await self.conn.get() - if not writer.transport.is_closing(): - return reader, writer - def close(self): - self.closed = True - try: - self.writer.close() - except Exception: - pass - async 
def start_server(self, handler): - for _ in range(self.count): - asyncio.ensure_future(self.server_run(handler)) - return self - async def server_run(self, handler): - errwait = 0 - while not self.closed: - if self.uri.unix: - wait = asyncio.open_unix_connection(path=self.uri.bind) - else: - wait = asyncio.open_connection(host=self.uri.host_name, port=self.uri.port, local_addr=(self.uri.lbind, 0) if self.uri.lbind else None) - try: - reader, writer = await asyncio.wait_for(wait, timeout=SOCKET_TIMEOUT) - writer.write(self.uri.auth) - self.writer = writer - try: - data = await reader.read_n(1) - except asyncio.TimeoutError: - data = None - if data and data[0] != 0: - reader._buffer[0:0] = data - asyncio.ensure_future(handler(reader, writer)) - else: - writer.close() - errwait = 0 - except Exception as ex: - try: - writer.close() - except Exception: - pass - if not self.closed: - await asyncio.sleep(errwait) - errwait = min(errwait*1.3 + 0.1, 30) - def client_run(self, args): - async def handler(reader, writer): - if self.uri.auth: - try: - assert self.uri.auth == (await reader.read_n(len(self.uri.auth))) - except Exception: - return - await self.conn.put((reader, writer)) - if self.uri.unix: - return asyncio.start_unix_server(handler, path=self.uri.bind) - else: - return asyncio.start_server(handler, host=self.uri.host_name, port=self.uri.port, reuse_port=args.get('ruport')) - -class ProxyURI(object): - def __init__(self, **kw): - self.__dict__.update(kw) - self.total = 0 +class ProxyDirect(object): + def __init__(self, lbind=None): + self.bind = 'DIRECT' + self.lbind = lbind + self.unix = False + self.alive = True + self.connections = 0 self.udpmap = {} - self.handler = None - self.streams = None - if self.backward: - self.backward = BackwardConnection(self, self.backward) + @property + def direct(self): + return type(self) is ProxyDirect def logtext(self, host, port): - if self.direct: - return f' -> {host}:{port}' - elif self.tunnel: - return f' ->{(" ssl" if self.sslclient else "")} {self.bind}' - else: - return f' -> {self.rproto.name+("+ssl" if self.sslclient else "")} {self.bind}' + self.relay.logtext(host, port) + return '' if host == 'tunnel' else f' -> {host}:{port}' + def match_rule(self, host, port): + return True def connection_change(self, delta): - self.total += delta - async def open_udp_connection(self, host, port, data, addr, reply): + self.connections += delta + def udp_packet_unpack(self, data): + return data + def destination(self, host, port): + return host, port + async def udp_open_connection(self, host, port, data, addr, reply): class Protocol(asyncio.DatagramProtocol): def __init__(prot, data): self.udpmap[addr] = prot @@ -294,8 +193,7 @@ def new_data_arrived(prot, data): prot.databuf.append(data) prot.update = time.perf_counter() def datagram_received(prot, data, addr): - data = self.cipher.datagram.decrypt(data) if self.cipher else data - data = self.rproto.udp_client(data) if not self.direct else data + data = self.udp_packet_unpack(data) reply(data) prot.update = time.perf_counter() def connection_lost(prot, exc): @@ -303,108 +201,31 @@ def connection_lost(prot, exc): if addr in self.udpmap: self.udpmap[addr].new_data_arrived(data) else: - if self.direct and host == 'tunnel': - raise Exception('Unknown tunnel endpoint') self.connection_change(1) if len(self.udpmap) > UDP_LIMIT: min_addr = min(self.udpmap, key=lambda x: self.udpmap[x].update) prot = self.udpmap.pop(min_addr) if prot.transport: prot.transport.close() - prot = Protocol(data) - remote_addr = (host, 
port) if self.direct else (self.host_name, self.port) - await asyncio.get_event_loop().create_datagram_endpoint(lambda: prot, remote_addr=remote_addr) - def prepare_udp_connection(self, host, port, data): - if not self.direct: - data = self.relay.prepare_udp_connection(host, port, data) - whost, wport = (host, port) if self.relay.direct else (self.relay.host_name, self.relay.port) - data = self.rproto.udp_connect(rauth=self.auth, host_name=whost, port=wport, data=data) - if self.cipher: - data = self.cipher.datagram.encrypt(data) + prot = lambda: Protocol(data) + remote = self.destination(host, port) + await asyncio.get_event_loop().create_datagram_endpoint(prot, remote_addr=remote) + def udp_prepare_connection(self, host, port, data): return data - def start_udp_server(self, args): - class Protocol(asyncio.DatagramProtocol): - def connection_made(prot, transport): - prot.transport = transport - def datagram_received(prot, data, addr): - asyncio.ensure_future(datagram_handler(prot.transport, data, addr, **vars(self), **args)) - return asyncio.get_event_loop().create_datagram_endpoint(Protocol, local_addr=(self.host_name, self.port)) + def wait_open_connection(self, host, port, local_addr, family): + return asyncio.open_connection(host=host, port=port, local_addr=local_addr, family=family) async def open_connection(self, host, port, local_addr, lbind, timeout=SOCKET_TIMEOUT): - if self.reuse or self.ssh: - if self.streams is None or self.streams.done() and (self.reuse and not self.handler): - self.streams = asyncio.get_event_loop().create_future() - else: - if not self.streams.done(): - await self.streams - return self.streams.result() try: local_addr = local_addr if self.lbind == 'in' else (self.lbind, 0) if self.lbind else \ local_addr if lbind == 'in' else (lbind, 0) if lbind else None family = 0 if local_addr is None else socket.AF_INET6 if ':' in local_addr[0] else socket.AF_INET - if self.direct: - if host == 'tunnel': - raise Exception('Unknown tunnel endpoint') - wait = asyncio.open_connection(host=host, port=port, local_addr=local_addr, family=family) - elif self.ssh: - try: - import asyncssh - for s in ('read_', 'read_n', 'read_until'): - setattr(asyncssh.SSHReader, s, getattr(asyncio.StreamReader, s)) - except Exception: - raise Exception('Missing library: "pip3 install asyncssh"') - username, password = self.auth.decode().split(':', 1) - if password.startswith(':'): - client_keys = [password[1:]] - password = None - else: - client_keys = None - conn = await asyncssh.connect(host=self.host_name, port=self.port, local_addr=local_addr, family=family, x509_trusted_certs=None, known_hosts=None, username=username, password=password, client_keys=client_keys, keepalive_interval=60) - if not self.streams.done(): - self.streams.set_result((conn, None)) - return conn, None - elif self.backward: - wait = self.backward.open_connection() - elif self.unix: - wait = asyncio.open_unix_connection(path=self.bind) - else: - wait = asyncio.open_connection(host=self.host_name, port=self.port, local_addr=local_addr, family=family) + wait = self.wait_open_connection(host, port, local_addr, family) reader, writer = await asyncio.wait_for(wait, timeout=timeout) except Exception as ex: - if self.reuse: - self.streams.set_exception(ex) - self.streams = None raise return reader, writer - def prepare_connection(self, reader_remote, writer_remote, host, port): - if self.reuse and not self.handler: - self.handler = self.rproto.get_handler(reader_remote, writer_remote, DUMMY) - return 
self.prepare_ciphers_and_headers(reader_remote, writer_remote, host, port, self.handler) - async def prepare_ciphers_and_headers(self, reader_remote, writer_remote, host, port, handler): - if not self.direct: - reader_remote, writer_remote = proto.sslwrap(reader_remote, writer_remote, self.sslclient, False, self.host_name) - if not handler or not handler.ready: - _, writer_cipher_r = await prepare_ciphers(self.cipher, reader_remote, writer_remote, self.bind) - else: - writer_cipher_r = None - whost, wport = (host, port) if self.relay.direct else (self.relay.host_name, self.relay.port) - if self.rproto.reuse(): - if not self.streams.done(): - self.streams.set_result((reader_remote, writer_remote)) - reader_remote, writer_remote = handler.connect(whost, wport) - elif self.ssh: - reader_remote, writer_remote = await reader_remote.open_connection(whost, wport) - else: - await self.rproto.connect(reader_remote=reader_remote, writer_remote=writer_remote, rauth=self.auth, host_name=whost, port=wport, writer_cipher_r=writer_cipher_r, myhost=self.host_name, sock=writer_remote.get_extra_info('socket')) - return await self.relay.prepare_ciphers_and_headers(reader_remote, writer_remote, host, port, handler) + async def prepare_connection(self, reader_remote, writer_remote, host, port): return reader_remote, writer_remote - def start_server(self, args): - handler = functools.partial(reuse_stream_handler if self.reuse else stream_handler, **vars(self), **args) - if self.backward: - return self.backward.start_server(handler) - elif self.unix: - return asyncio.start_unix_server(handler, path=self.bind) - else: - return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport')) async def tcp_connect(self, host, port, local_addr=None, lbind=None): reader, writer = await self.open_connection(host, port, local_addr, lbind) try: @@ -416,83 +237,612 @@ async def tcp_connect(self, host, port, local_addr=None, lbind=None): async def udp_sendto(self, host, port, data, answer_cb, local_addr=None): if local_addr is None: local_addr = random.randrange(2**32) - data = self.prepare_udp_connection(host, port, data) - await self.open_udp_connection(host, port, data, local_addr, answer_cb) - @classmethod - def compile_rule(cls, filename): - if filename.startswith("{") and filename.endswith("}"): - return re.compile(filename[1:-1]).match - with open(filename) as f: - return re.compile('(:?'+''.join('|'.join(i.strip() for i in f if i.strip() and not i.startswith('#')))+')$').match - @classmethod - def compile_relay(cls, uri): - tail = cls.DIRECT - for urip in reversed(uri.split('__')): - tail = cls.compile(urip, tail) - return tail - @classmethod - def compile(cls, uri, relay=None): - scheme, _, uri = uri.partition('://') - url = urllib.parse.urlparse('s://'+uri) - rawprotos = scheme.split('+') - err_str, protos = proto.get_protos(rawprotos) - if err_str: - raise argparse.ArgumentTypeError(err_str) - if 'ssl' in rawprotos or 'secure' in rawprotos: - import ssl - sslserver = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - sslclient = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - if 'ssl' in rawprotos: - sslclient.check_hostname = False - sslclient.verify_mode = ssl.CERT_NONE + data = self.udp_prepare_connection(host, port, data) + await self.udp_open_connection(host, port, data, local_addr, answer_cb) +DIRECT = ProxyDirect() + +class ProxySimple(ProxyDirect): + def __init__(self, jump, protos, cipher, users, rule, bind, + host_name, port, unix, lbind, sslclient, 
+        super().__init__(lbind)
+        self.protos = protos
+        self.cipher = cipher
+        self.users = users
+        self.rule = compile_rule(rule) if rule else None
+        self.bind = bind
+        self.host_name = host_name
+        self.port = port
+        self.unix = unix
+        self.sslclient = sslclient
+        self.sslserver = sslserver
+        self.jump = jump
+    def logtext(self, host, port):
+        return f' -> {self.rproto.name+("+ssl" if self.sslclient else "")} {self.bind}' + self.jump.logtext(host, port)
+    def match_rule(self, host, port):
+        return (self.rule is None) or self.rule(host) or self.rule(str(port))
+    @property
+    def rproto(self):
+        return self.protos[0]
+    @property
+    def auth(self):
+        return self.users[0] if self.users else b''
+    def udp_packet_unpack(self, data):
+        data = self.cipher.datagram.decrypt(data) if self.cipher else data
+        return self.rproto.udp_unpack(data)
+    def destination(self, host, port):
+        return self.host_name, self.port
+    def udp_prepare_connection(self, host, port, data):
+        data = self.jump.udp_prepare_connection(host, port, data)
+        whost, wport = self.jump.destination(host, port)
+        data = self.rproto.udp_connect(rauth=self.auth, host_name=whost, port=wport, data=data)
+        if self.cipher:
+            data = self.cipher.datagram.encrypt(data)
+        return data
+    def udp_start_server(self, args):
+        class Protocol(asyncio.DatagramProtocol):
+            def connection_made(prot, transport):
+                prot.transport = transport
+            def datagram_received(prot, data, addr):
+                asyncio.ensure_future(datagram_handler(prot.transport, data, addr, **vars(self), **args))
+        return asyncio.get_event_loop().create_datagram_endpoint(Protocol, local_addr=(self.host_name, self.port))
+    def wait_open_connection(self, host, port, local_addr, family):
+        if self.unix:
+            return asyncio.open_unix_connection(path=self.bind)
         else:
-            sslserver = sslclient = None
-        protonames = [i.name for i in protos]
-        if 'pack' in protonames and relay and relay != cls.DIRECT:
-            raise argparse.ArgumentTypeError('pack protocol cannot relay to other proxy')
-        urlpath, _, plugins = url.path.partition(',')
-        urlpath, _, lbind = urlpath.partition('@')
-        plugins = plugins.split(',') if plugins else None
-        cipher, _, loc = url.netloc.rpartition('@')
-        if cipher:
-            from .cipher import get_cipher
-            if ':' not in cipher:
+            return asyncio.open_connection(host=self.host_name, port=self.port, local_addr=local_addr, family=family)
+    async def prepare_connection(self, reader_remote, writer_remote, host, port):
+        reader_remote, writer_remote = proto.sslwrap(reader_remote, writer_remote, self.sslclient, False, self.host_name)
+        _, writer_cipher_r = await prepare_ciphers(self.cipher, reader_remote, writer_remote, self.bind)
+        whost, wport = self.jump.destination(host, port)
+        await self.rproto.connect(reader_remote=reader_remote, writer_remote=writer_remote, rauth=self.auth, host_name=whost, port=wport, writer_cipher_r=writer_cipher_r, myhost=self.host_name, sock=writer_remote.get_extra_info('socket'))
+        return await self.jump.prepare_connection(reader_remote, writer_remote, host, port)
+    def start_server(self, args, stream_handler=stream_handler):
+        handler = functools.partial(stream_handler, **vars(self), **args)
+        if self.unix:
+            return asyncio.start_unix_server(handler, path=self.bind)
+        else:
+            return asyncio.start_server(handler, host=self.host_name, port=self.port, reuse_port=args.get('ruport'))
+
+class ProxyH2(ProxySimple):
+    def __init__(self, sslserver, sslclient, **kw):
+        super().__init__(sslserver=None, sslclient=None, **kw)
+        self.handshake = None
+        self.h2sslserver = sslserver
+        self.h2sslclient = sslclient
+    async def handler(self, reader, writer, client_side=True, stream_handler=None, **kw):
+        import h2.connection, h2.config, h2.events
+        reader, writer = proto.sslwrap(reader, writer, self.h2sslclient if client_side else self.h2sslserver, not client_side, None)
+        config = h2.config.H2Configuration(client_side=client_side)
+        conn = h2.connection.H2Connection(config=config)
+        streams = {}
+        conn.initiate_connection()
+        writer.write(conn.data_to_send())
+        while not reader.at_eof() and not writer.is_closing():
+            try:
+                data = await reader.read(65536)
+                if not data:
+                    break
+                events = conn.receive_data(data)
+            except Exception:
+                break
+            writer.write(conn.data_to_send())
+            for event in events:
+                if isinstance(event, h2.events.RequestReceived) and not client_side:
+                    if event.stream_id not in streams:
+                        stream_reader, stream_writer = self.get_stream(conn, writer, event.stream_id)
+                        streams[event.stream_id] = (stream_reader, stream_writer)
+                        asyncio.ensure_future(stream_handler(stream_reader, stream_writer))
+                    else:
+                        stream_reader, stream_writer = streams[event.stream_id]
+                    stream_writer.headers.set_result(event.headers)
+                elif isinstance(event, h2.events.SettingsAcknowledged) and client_side:
+                    self.handshake.set_result((conn, streams, writer))
+                elif isinstance(event, h2.events.DataReceived):
+                    stream_reader, stream_writer = streams[event.stream_id]
+                    stream_reader.feed_data(event.data)
+                    conn.acknowledge_received_data(len(event.data), event.stream_id)
+                    writer.write(conn.data_to_send())
+                elif isinstance(event, h2.events.StreamEnded) or isinstance(event, h2.events.StreamReset):
+                    stream_reader, stream_writer = streams[event.stream_id]
+                    stream_reader.feed_eof()
+                    if not stream_writer.closed:
+                        stream_writer.close()
+                elif isinstance(event, h2.events.ConnectionTerminated):
+                    break
+                elif isinstance(event, h2.events.WindowUpdated):
+                    if event.stream_id in streams:
+                        stream_reader, stream_writer = streams[event.stream_id]
+                        stream_writer.window_update()
+            writer.write(conn.data_to_send())
+        writer.close()
+    def get_stream(self, conn, writer, stream_id):
+        reader = asyncio.StreamReader()
+        write_buffer = bytearray()
+        write_wait = asyncio.Event()
+        write_full = asyncio.Event()
+        class StreamWriter():
+            def __init__(self):
+                self.closed = False
+                self.headers = asyncio.get_event_loop().create_future()
+            def get_extra_info(self, key):
+                return writer.get_extra_info(key)
+            def write(self, data):
+                write_buffer.extend(data)
+                write_wait.set()
+            def drain(self):
+                writer.write(conn.data_to_send())
+                return writer.drain()
+            def is_closing(self):
+                return self.closed
+            def close(self):
+                self.closed = True
+                write_wait.set()
+            def window_update(self):
+                write_full.set()
+            def send_headers(self, headers):
+                conn.send_headers(stream_id, headers)
+                writer.write(conn.data_to_send())
+        stream_writer = StreamWriter()
+        async def write_job():
+            while not stream_writer.closed:
+                while len(write_buffer) > 0:
+                    while conn.local_flow_control_window(stream_id) <= 0:
+                        write_full.clear()
+                        await write_full.wait()
+                        if stream_writer.closed:
+                            break
+                    chunk_size = min(conn.local_flow_control_window(stream_id), len(write_buffer), conn.max_outbound_frame_size)
+                    conn.send_data(stream_id, write_buffer[:chunk_size])
+                    writer.write(conn.data_to_send())
+                    del write_buffer[:chunk_size]
+                if not stream_writer.closed:
+                    write_wait.clear()
+                    await write_wait.wait()
+            conn.send_data(stream_id, b'', end_stream=True)
+            writer.write(conn.data_to_send())
+        asyncio.ensure_future(write_job())
+        return reader, stream_writer
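The write_job task above applies HTTP/2 flow control: it never sends more than the stream's local flow-control window or one frame at a time, and it parks on write_full until a WINDOW_UPDATE arrives. Reduced to plain asyncio, the same backpressure loop looks roughly like the sketch below; window, max_frame and send are illustrative stand-ins for the h2 calls used in the patch, not part of it.

.. code:: python

    import asyncio

    async def drain_buffer(buffer: bytearray, window, max_frame: int, send, window_updated: asyncio.Event):
        # Flush `buffer` in chunks no larger than the flow-control window or one frame,
        # sleeping on `window_updated` whenever the peer's window is exhausted.
        while buffer:
            while window() <= 0:
                window_updated.clear()
                await window_updated.wait()
            chunk = min(window(), len(buffer), max_frame)
            send(bytes(buffer[:chunk]))
            del buffer[:chunk]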
+    async def wait_h2_connection(self, local_addr, family):
+        if self.handshake is not None:
+            if not self.handshake.done():
+                await self.handshake
+        else:
+            self.handshake = asyncio.get_event_loop().create_future()
+            reader, writer = await super().wait_open_connection(None, None, local_addr, family)
+            asyncio.ensure_future(self.handler(reader, writer))
+            await self.handshake
+        return self.handshake.result()
+    async def wait_open_connection(self, host, port, local_addr, family):
+        conn, streams, writer = await self.wait_h2_connection(local_addr, family)
+        stream_id = conn.get_next_available_stream_id()
+        conn._begin_new_stream(stream_id, stream_id%2)
+        stream_reader, stream_writer = self.get_stream(conn, writer, stream_id)
+        streams[stream_id] = (stream_reader, stream_writer)
+        return stream_reader, stream_writer
+    def start_server(self, args, stream_handler=stream_handler):
+        handler = functools.partial(stream_handler, **vars(self), **args)
+        return super().start_server(args, functools.partial(self.handler, client_side=False, stream_handler=handler))
+
+class ProxyQUIC(ProxySimple):
+    def __init__(self, quicserver, quicclient, **kw):
+        super().__init__(**kw)
+        self.quicserver = quicserver
+        self.quicclient = quicclient
+        self.handshake = None
+    def patch_writer(self, writer):
+        async def drain():
+            writer._transport.protocol.transmit()
+        #print('stream_id', writer.get_extra_info("stream_id"))
+        remote_addr = writer._transport.protocol._quic._network_paths[0].addr
+        writer.get_extra_info = dict(peername=remote_addr, sockname=remote_addr).get
+        writer.drain = drain
+        closed = False
+        writer.is_closing = lambda: closed
+        def close():
+            nonlocal closed
+            closed = True
+            try:
+                writer.write_eof()
+            except Exception:
+                pass
+        writer.close = close
+    async def wait_quic_connection(self):
+        if self.handshake is not None:
+            if not self.handshake.done():
+                await self.handshake
+        else:
+            self.handshake = asyncio.get_event_loop().create_future()
+            import aioquic.asyncio, aioquic.quic.events
+            class Protocol(aioquic.asyncio.QuicConnectionProtocol):
+                def quic_event_received(s, event):
+                    if isinstance(event, aioquic.quic.events.HandshakeCompleted):
+                        self.handshake.set_result(s)
+                    elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
+                        self.handshake = None
+                        self.quic_egress_acm = None
+                    elif isinstance(event, aioquic.quic.events.StreamDataReceived):
+                        if event.stream_id in self.udpmap:
+                            self.udpmap[event.stream_id](self.udp_packet_unpack(event.data))
+                            return
+                    super().quic_event_received(event)
+            self.quic_egress_acm = aioquic.asyncio.connect(self.host_name, self.port, create_protocol=Protocol, configuration=self.quicclient)
+            conn = await self.quic_egress_acm.__aenter__()
+            await self.handshake
+    async def udp_open_connection(self, host, port, data, addr, reply):
+        await self.wait_quic_connection()
+        conn = self.handshake.result()
+        if addr in self.udpmap:
+            stream_id = self.udpmap[addr]
+        else:
+            stream_id = conn._quic.get_next_available_stream_id(False)
+            self.udpmap[addr] = stream_id
+            self.udpmap[stream_id] = reply
+            conn._quic._get_or_create_stream_for_send(stream_id)
+        conn._quic.send_stream_data(stream_id, data, False)
+        conn.transmit()
+    async def wait_open_connection(self, *args):
+        await self.wait_quic_connection()
+        conn = self.handshake.result()
+        stream_id = conn._quic.get_next_available_stream_id(False)
+        conn._quic._get_or_create_stream_for_send(stream_id)
+        reader, writer = conn._create_stream(stream_id)
+        self.patch_writer(writer)
+        return reader, writer
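wait_quic_connection above (like wait_h2_connection before it and wait_ssh_connection later in this patch) lets many concurrent callers share one underlying connection by parking them on a single future that the first caller resolves once the handshake completes. A minimal sketch of that idiom in plain asyncio, assuming a hypothetical really_connect() coroutine:

.. code:: python

    import asyncio

    class SharedConnection:
        def __init__(self, really_connect):
            self.really_connect = really_connect   # hypothetical coroutine opening the real connection
            self.handshake = None                  # future shared by every caller
        async def get(self):
            if self.handshake is not None:
                if not self.handshake.done():
                    await self.handshake           # another caller is already connecting; wait for it
            else:
                self.handshake = asyncio.get_event_loop().create_future()
                conn = await self.really_connect() # only the first caller dials out
                self.handshake.set_result(conn)
            return self.handshake.result()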
+    async def udp_start_server(self, args):
+        import aioquic.asyncio, aioquic.quic.events
+        class Protocol(aioquic.asyncio.QuicConnectionProtocol):
+            def quic_event_received(s, event):
+                if isinstance(event, aioquic.quic.events.StreamDataReceived):
+                    stream_id = event.stream_id
+                    addr = ('quic '+self.bind, stream_id)
+                    event.sendto = lambda data, addr: (s._quic.send_stream_data(stream_id, data, False), s.transmit())
+                    event.get_extra_info = {}.get
+                    asyncio.ensure_future(datagram_handler(event, event.data, addr, **vars(self), **args))
+                    return
+                super().quic_event_received(event)
+        return await aioquic.asyncio.serve(self.host_name, self.port, configuration=self.quicserver, create_protocol=Protocol), None
+    def start_server(self, args, stream_handler=stream_handler):
+        import aioquic.asyncio
+        def handler(reader, writer):
+            self.patch_writer(writer)
+            asyncio.ensure_future(stream_handler(reader, writer, **vars(self), **args))
+        return aioquic.asyncio.serve(self.host_name, self.port, configuration=self.quicserver, stream_handler=handler)
+
+class ProxyH3(ProxyQUIC):
+    def get_stream(self, conn, stream_id):
+        remote_addr = conn._quic._network_paths[0].addr
+        reader = asyncio.StreamReader()
+        class StreamWriter():
+            def __init__(self):
+                self.closed = False
+                self.headers = asyncio.get_event_loop().create_future()
+            def get_extra_info(self, key):
+                return dict(peername=remote_addr, sockname=remote_addr).get(key)
+            def write(self, data):
+                conn.http.send_data(stream_id, data, False)
+                conn.transmit()
+            async def drain(self):
+                conn.transmit()
+            def is_closing(self):
+                return self.closed
+            def close(self):
+                if not self.closed:
+                    conn.http.send_data(stream_id, b'', True)
+                    conn.transmit()
+                    conn.close_stream(stream_id)
+                    self.closed = True
+            def send_headers(self, headers):
+                conn.http.send_headers(stream_id, [(i.encode(), j.encode()) for i, j in headers])
+                conn.transmit()
+        return reader, StreamWriter()
+    def get_protocol(self, server_side=False, handler=None):
+        import aioquic.asyncio, aioquic.quic.events, aioquic.h3.connection, aioquic.h3.events
+        class Protocol(aioquic.asyncio.QuicConnectionProtocol):
+            def __init__(s, *args, **kw):
+                super().__init__(*args, **kw)
+                s.http = aioquic.h3.connection.H3Connection(s._quic)
+                s.streams = {}
+            def quic_event_received(s, event):
+                if not server_side:
+                    if isinstance(event, aioquic.quic.events.HandshakeCompleted):
+                        self.handshake.set_result(s)
+                    elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
+                        self.handshake = None
+                        self.quic_egress_acm = None
+                if s.http is not None:
+                    for http_event in s.http.handle_event(event):
+                        s.http_event_received(http_event)
+            def http_event_received(s, event):
+                if isinstance(event, aioquic.h3.events.HeadersReceived):
+                    if event.stream_id not in s.streams and server_side:
+                        reader, writer = s.create_stream(event.stream_id)
+                        writer.headers.set_result(event.headers)
+                        asyncio.ensure_future(handler(reader, writer))
+                elif isinstance(event, aioquic.h3.events.DataReceived) and event.stream_id in s.streams:
+                    reader, writer = s.streams[event.stream_id]
+                    if event.data:
+                        reader.feed_data(event.data)
+                    if event.stream_ended:
+                        reader.feed_eof()
+                        s.close_stream(event.stream_id)
+            def create_stream(s, stream_id=None):
+                if stream_id is None:
+                    stream_id = s._quic.get_next_available_stream_id(False)
+                    s._quic._get_or_create_stream_for_send(stream_id)
+                reader, writer = self.get_stream(s, stream_id)
+                s.streams[stream_id] = (reader, writer)
+                return reader, writer
+            def close_stream(s, stream_id):
+                if stream_id in s.streams:
+                    reader, writer = s.streams[stream_id]
+                    if reader.at_eof() and writer.is_closing():
+                        s.streams.pop(stream_id)
+        return Protocol
+    async def wait_h3_connection(self):
+        if self.handshake is not None:
+            if not self.handshake.done():
+                await self.handshake
+        else:
+            import aioquic.asyncio
+            self.handshake = asyncio.get_event_loop().create_future()
+            self.quic_egress_acm = aioquic.asyncio.connect(self.host_name, self.port, create_protocol=self.get_protocol(), configuration=self.quicclient)
+            conn = await self.quic_egress_acm.__aenter__()
+            await self.handshake
+    async def wait_open_connection(self, *args):
+        await self.wait_h3_connection()
+        return self.handshake.result().create_stream()
+    def start_server(self, args, stream_handler=stream_handler):
+        import aioquic.asyncio
+        return aioquic.asyncio.serve(self.host_name, self.port, configuration=self.quicserver, create_protocol=self.get_protocol(True, functools.partial(stream_handler, **vars(self), **args)))
+
+class ProxySSH(ProxySimple):
+    def __init__(self, **kw):
+        super().__init__(**kw)
+        self.sshconn = None
+    def logtext(self, host, port):
+        return f' -> sshtunnel {self.bind}' + self.jump.logtext(host, port)
+    def patch_stream(self, ssh_reader, writer, host, port):
+        reader = asyncio.StreamReader()
+        async def channel():
+            while not ssh_reader.at_eof() and not writer.is_closing():
+                buf = await ssh_reader.read(65536)
+                if not buf:
+                    break
+                reader.feed_data(buf)
+            reader.feed_eof()
+        asyncio.ensure_future(channel())
+        remote_addr = ('ssh:'+str(host), port)
+        writer.get_extra_info = dict(peername=remote_addr, sockname=remote_addr).get
+        return reader, writer
+    async def wait_ssh_connection(self, local_addr=None, family=0, tunnel=None):
+        if self.sshconn is not None:
+            if not self.sshconn.done():
+                await self.sshconn
+        else:
+            self.sshconn = asyncio.get_event_loop().create_future()
+            try:
+                import asyncssh
+            except Exception:
+                raise Exception('Missing library: "pip3 install asyncssh"')
+            username, password = self.auth.decode().split(':', 1)
+            if password.startswith(':'):
+                client_keys = [password[1:]]
+                password = None
+            else:
+                client_keys = None
+            conn = await asyncssh.connect(host=self.host_name, port=self.port, local_addr=local_addr, family=family, x509_trusted_certs=None, known_hosts=None, username=username, password=password, client_keys=client_keys, keepalive_interval=60, tunnel=tunnel)
+            self.sshconn.set_result(conn)
+    async def wait_open_connection(self, host, port, local_addr, family, tunnel=None):
+        await self.wait_ssh_connection(local_addr, family, tunnel)
+        conn = self.sshconn.result()
+        if isinstance(self.jump, ProxySSH):
+            reader, writer = await self.jump.wait_open_connection(host, port, None, None, conn)
+        else:
+            host, port = self.jump.destination(host, port)
+            if self.jump.unix:
+                reader, writer = await conn.open_unix_connection(self.jump.bind)
+            else:
+                reader, writer = await conn.open_connection(host, port)
+            reader, writer = self.patch_stream(reader, writer, host, port)
+        return reader, writer
+    async def start_server(self, args, stream_handler=stream_handler, tunnel=None):
+        if type(self.jump) is ProxyDirect:
+            raise Exception('ssh server mode unsupported')
+        await self.wait_ssh_connection(tunnel=tunnel)
+        conn = self.sshconn.result()
+        if isinstance(self.jump, ProxySSH):
+            return await self.jump.start_server(args, stream_handler, conn)
+        else:
+            def handler(host, port):
+                def handler_stream(reader, writer):
+                    reader, writer = self.patch_stream(reader, writer, host, port)
+                    return stream_handler(reader, writer, **vars(self.jump), **args)
+                return handler_stream
+            if self.jump.unix:
+                return await conn.start_unix_server(handler, self.jump.bind)
+            else:
+                return await conn.start_server(handler, self.jump.host_name, self.jump.port)
+
+class ProxyBackward(ProxySimple):
+    def __init__(self, backward, backward_num, **kw):
+        super().__init__(**kw)
+        self.backward = backward
+        self.server = backward
+        while type(self.server.jump) != ProxyDirect:
+            self.server = self.server.jump
+        self.backward_num = backward_num
+        self.closed = False
+        self.writers = set()
+        self.conn = asyncio.Queue()
+    async def wait_open_connection(self, *args):
+        while True:
+            reader, writer = await self.conn.get()
+            if not reader.at_eof() and not writer.is_closing():
+                return reader, writer
+    def close(self):
+        self.closed = True
+        for writer in self.writers:
+            try:
+                writer.close()
+            except Exception:
+                pass
+    async def start_server(self, args, stream_handler=stream_handler):
+        handler = functools.partial(stream_handler, **vars(self.server), **args)
+        for _ in range(self.backward_num):
+            asyncio.ensure_future(self.start_server_run(handler))
+        return self
+    async def start_server_run(self, handler):
+        errwait = 0
+        while not self.closed:
+            wait = self.backward.open_connection(self.host_name, self.port, self.lbind, None)
+            try:
+                reader, writer = await asyncio.wait_for(wait, timeout=SOCKET_TIMEOUT)
+                if self.closed:
+                    writer.close()
+                    break
+                if isinstance(self.server, ProxyQUIC):
+                    writer.write(b'\x01')
+                writer.write(self.server.auth)
+                self.writers.add(writer)
                 try:
-                    cipher = base64.b64decode(cipher).decode()
+                    data = await reader.read_n(1)
+                except asyncio.TimeoutError:
+                    data = None
+                if data and data[0] != 0:
+                    reader.rollback(data)
+                    asyncio.ensure_future(handler(reader, writer))
+                else:
+                    writer.close()
+                    errwait = 0
+                self.writers.discard(writer)
+                writer = None
+            except Exception as ex:
+                try:
+                    writer.close()
                 except Exception:
                     pass
-            if ':' not in cipher:
-                raise argparse.ArgumentTypeError('userinfo must be "cipher:key"')
-            err_str, cipher = get_cipher(cipher)
-            if err_str:
-                raise argparse.ArgumentTypeError(err_str)
-            if plugins:
-                from .plugin import get_plugin
-                for name in plugins:
-                    if not name: continue
-                    err_str, plugin = get_plugin(name)
-                    if err_str:
-                        raise argparse.ArgumentTypeError(err_str)
-                    cipher.plugins.append(plugin)
-        match = cls.compile_rule(url.query) if url.query else None
-        if loc:
-            host_name, _, port = loc.rpartition(':')
-            port = int(port) if port else (22 if 'ssh' in rawprotos else 8080)
+                if not self.closed:
+                    await asyncio.sleep(errwait)
+                    errwait = min(errwait*1.3 + 0.1, 30)
+    def start_backward_client(self, args):
+        async def handler(reader, writer, **kw):
+            auth = self.server.auth
+            if isinstance(self.server, ProxyQUIC):
+                auth = b'\x01'+auth
+            if auth:
+                try:
+                    assert auth == (await reader.read_n(len(auth)))
+                except Exception:
+                    return
+            await self.conn.put((reader, writer))
+        return self.backward.start_server(args, handler)
+
+
+def compile_rule(filename):
+    if filename.startswith("{") and filename.endswith("}"):
+        return re.compile(filename[1:-1]).match
+    with open(filename) as f:
+        return re.compile('(:?'+''.join('|'.join(i.strip() for i in f if i.strip() and not i.startswith('#')))+')$').match
+
+def proxies_by_uri(uri_jumps):
+    jump = DIRECT
+    for uri in reversed(uri_jumps.split('__')):
+        jump = proxy_by_uri(uri, jump)
+    return jump
+
+sslcontexts = []
+
+def proxy_by_uri(uri, jump):
+    scheme, _, uri = uri.partition('://')
+    url = urllib.parse.urlparse('s://'+uri)
+    rawprotos = [i.lower() for i in scheme.split('+')]
+    err_str, protos = proto.get_protos(rawprotos)
+    protonames = [i.name for i in protos]
+    if err_str:
+        raise argparse.ArgumentTypeError(err_str)
+    if 'ssl' in rawprotos or 'secure' in rawprotos:
+        import ssl
+        sslserver = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+        sslclient = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
+        if 'ssl' in rawprotos:
+            sslclient.check_hostname = False
+            sslclient.verify_mode = ssl.CERT_NONE
+        sslcontexts.append(sslserver)
+        sslcontexts.append(sslclient)
+    else:
+        sslserver = sslclient = None
+    if 'quic' in rawprotos or 'h3' in protonames:
+        try:
+            import ssl, aioquic.quic.configuration
+        except Exception:
+            raise Exception('Missing library: "pip3 install aioquic"')
+        quicserver = aioquic.quic.configuration.QuicConfiguration(is_client=False, max_stream_data=2**60, max_data=2**60, idle_timeout=SOCKET_TIMEOUT)
+        quicclient = aioquic.quic.configuration.QuicConfiguration(max_stream_data=2**60, max_data=2**60, idle_timeout=SOCKET_TIMEOUT*5)
+        quicclient.verify_mode = ssl.CERT_NONE
+        sslcontexts.append(quicserver)
+        sslcontexts.append(quicclient)
+    if 'h2' in rawprotos:
+        try:
+            import h2
+        except Exception:
+            raise Exception('Missing library: "pip3 install h2"')
+    urlpath, _, plugins = url.path.partition(',')
+    urlpath, _, lbind = urlpath.partition('@')
+    plugins = plugins.split(',') if plugins else None
+    cipher, _, loc = url.netloc.rpartition('@')
+    if cipher:
+        from .cipher import get_cipher
+        if ':' not in cipher:
+            try:
+                cipher = base64.b64decode(cipher).decode()
+            except Exception:
+                pass
+            if ':' not in cipher:
+                raise argparse.ArgumentTypeError('userinfo must be "cipher:key"')
+        err_str, cipher = get_cipher(cipher)
+        if err_str:
+            raise argparse.ArgumentTypeError(err_str)
+        if plugins:
+            from .plugin import get_plugin
+            for name in plugins:
+                if not name: continue
+                err_str, plugin = get_plugin(name)
+                if err_str:
+                    raise argparse.ArgumentTypeError(err_str)
+                cipher.plugins.append(plugin)
+    if loc:
+        host_name, port = proto.netloc_split(loc, default_port=22 if 'ssh' in rawprotos else 8080)
+    else:
+        host_name = port = None
+    if url.fragment.startswith('#'):
+        with open(url.fragment[1:]) as f:
+            auth = f.read().rstrip().encode()
+    else:
+        auth = url.fragment.encode()
+    users = [i.rstrip() for i in auth.split(b'\n')] if auth else None
+    if 'direct' in protonames:
+        return ProxyDirect(lbind=lbind)
+    else:
+        params = dict(jump=jump, protos=protos, cipher=cipher, users=users, rule=url.query, bind=loc or urlpath,
+                      host_name=host_name, port=port, unix=not loc, lbind=lbind, sslclient=sslclient, sslserver=sslserver)
+        if 'quic' in rawprotos:
+            proxy = ProxyQUIC(quicserver, quicclient, **params)
+        elif 'h3' in protonames:
+            proxy = ProxyH3(quicserver, quicclient, **params)
+        elif 'h2' in protonames:
+            proxy = ProxyH2(**params)
+        elif 'ssh' in protonames:
+            proxy = ProxySSH(**params)
        else:
-            host_name = port = None
-    return ProxyURI(protos=protos, rproto=protos[0], cipher=cipher, auth=url.fragment.encode(), \
-                    match=match, bind=loc or urlpath, host_name=host_name, port=port, \
-                    unix=not loc, lbind=lbind, sslclient=sslclient, sslserver=sslserver, \
-                    alive=True, direct='direct' in protonames, tunnel='tunnel' in protonames, \
-                    reuse='pack' in protonames or relay and relay.reuse, backward=rawprotos.count('in'), \
-                    ssh='ssh' in rawprotos, relay=relay)
-ProxyURI.DIRECT = ProxyURI(direct=True, tunnel=False, reuse=False, relay=None, alive=True, match=None, cipher=None, backward=None, ssh=None, lbind=None)
+            proxy = ProxySimple(**params)
+        if 'in' in rawprotos:
+            proxy = ProxyBackward(proxy, rawprotos.count('in'), **params)
+        return proxy
 async def test_url(url, rserver):
     url = urllib.parse.urlparse(url)
     assert url.scheme in ('http', 'https'), f'Unknown scheme {url.scheme}'
-    host_name, _, port = url.netloc.partition(':')
-    port = int(port) if port else 80 if url.scheme == 'http' else 443
+    host_name, port = proto.netloc_split(url.netloc, default_port = 80 if url.scheme=='http' else 443)
     initbuf = f'GET {url.path or "/"} HTTP/1.1\r\nHost: {host_name}\r\nUser-Agent: pproxy-{__version__}\r\nAccept: */*\r\nConnection: close\r\n\r\n'.encode()
     for roption in rserver:
         print(f'============ {roption.bind} ============')
@@ -516,8 +866,8 @@ async def test_url(url, rserver):
         print(headers.decode()[:-4])
         print(f'--------------------------------')
         body = bytearray()
-        while 1:
-            s = await reader.read_()
+        while not reader.at_eof():
+            s = await reader.read(65536)
             if not s:
                 break
             body.extend(s)
@@ -526,11 +876,11 @@ def main():
     parser = argparse.ArgumentParser(description=__description__+'\nSupported protocols: http,socks4,socks5,shadowsocks,shadowsocksr,redirect,pf,tunnel', epilog=f'Online help: <{__url__}>')
-    parser.add_argument('-l', dest='listen', default=[], action='append', type=ProxyURI.compile, help='tcp server uri (default: http+socks4+socks5://:8080/)')
-    parser.add_argument('-r', dest='rserver', default=[], action='append', type=ProxyURI.compile_relay, help='tcp remote server uri (default: direct)')
-    parser.add_argument('-ul', dest='ulisten', default=[], action='append', type=ProxyURI.compile, help='udp server setting uri (default: none)')
-    parser.add_argument('-ur', dest='urserver', default=[], action='append', type=ProxyURI.compile_relay, help='udp remote server uri (default: direct)')
-    parser.add_argument('-b', dest='block', type=ProxyURI.compile_rule, help='block regex rules')
+    parser.add_argument('-l', dest='listen', default=[], action='append', type=proxies_by_uri, help='tcp server uri (default: http+socks4+socks5://:8080/)')
+    parser.add_argument('-r', dest='rserver', default=[], action='append', type=proxies_by_uri, help='tcp remote server uri (default: direct)')
+    parser.add_argument('-ul', dest='ulisten', default=[], action='append', type=proxies_by_uri, help='udp server setting uri (default: none)')
+    parser.add_argument('-ur', dest='urserver', default=[], action='append', type=proxies_by_uri, help='udp remote server uri (default: direct)')
+    parser.add_argument('-b', dest='block', type=compile_rule, help='block regex rules')
     parser.add_argument('-a', dest='alived', default=0, type=int, help='interval to check remote alive (default: no check)')
     parser.add_argument('-s', dest='salgorithm', default='fa', choices=('fa', 'rr', 'rc', 'lc'), help='scheduling algorithm (default: first_available)')
     parser.add_argument('-d', dest='debug', action='count', help='turn on debug to see tracebacks (default: no debug)')
@@ -545,16 +895,23 @@ def main():
     parser.add_argument('--test', help='test this url for all remote proxies and exit')
     parser.add_argument('--version', action='version', version=f'%(prog)s {__version__}')
     args = parser.parse_args()
+    if args.sslfile:
+        sslfile = args.sslfile.split(',')
+        for context in sslcontexts:
+            context.load_cert_chain(*sslfile)
+    elif any(map(lambda o: o.sslclient or isinstance(o, ProxyQUIC), args.listen+args.ulisten)):
+        print('You must specify --ssl to listen in ssl mode')
+        return
     if args.test:
         asyncio.get_event_loop().run_until_complete(test_url(args.test, args.rserver))
         return
     if not args.listen and not args.ulisten:
-        args.listen.append(ProxyURI.compile_relay('http+socks4+socks5://:8080/'))
+        args.listen.append(proxies_by_uri('http+socks4+socks5://:8080/'))
     args.httpget = {}
     if args.pac:
         pactext = 'function FindProxyForURL(u,h){' + (f'var b=/^(:?{args.block.__self__.pattern})$/i;if(b.test(h))return "";' if args.block else '')
         for i, option in enumerate(args.rserver):
-            pactext += (f'var m{i}=/^(:?{option.match.__self__.pattern})$/i;if(m{i}.test(h))' if option.match else '') + 'return "PROXY %(host)s";'
+            pactext += (f'var m{i}=/^(:?{option.rule.__self__.pattern})$/i;if(m{i}.test(h))' if option.rule else '') + 'return "PROXY %(host)s";'
         args.httpget[args.pac] = pactext+'return "DIRECT";}'
         args.httpget[args.pac+'/all'] = 'function FindProxyForURL(u,h){return "PROXY %(host)s";}'
         args.httpget[args.pac+'/none'] = 'function FindProxyForURL(u,h){return "DIRECT";}'
@@ -562,15 +919,6 @@ def main():
         path, filename = gets.split(',', 1)
         with open(filename, 'rb') as f:
             args.httpget[path] = f.read()
-    if args.sslfile:
-        sslfile = args.sslfile.split(',')
-        for option in args.listen:
-            if option.sslclient:
-                option.sslclient.load_cert_chain(*sslfile)
-                option.sslserver.load_cert_chain(*sslfile)
-    elif any(map(lambda o: o.sslclient, args.listen)):
-        print('You must specify --ssl to listen in ssl mode')
-        return
     if args.daemon:
         try:
             __import__('daemon').DaemonContext().open()
@@ -598,15 +946,15 @@ def main():
     for option in args.ulisten:
         print('Serving on UDP', option.bind, 'by', ",".join(i.name for i in option.protos), f'({option.cipher.name})' if option.cipher else '')
        try:
-            server, protocol = loop.run_until_complete(option.start_udp_server(vars(args)))
+            server, protocol = loop.run_until_complete(option.udp_start_server(vars(args)))
             servers.append(server)
         except Exception as ex:
             print('Start server failed.\n\t==>', ex)
     for option in args.rserver:
-        if option.backward:
+        if isinstance(option, ProxyBackward):
             print('Serving on', option.bind, 'backward by', ",".join(i.name for i in option.protos) + ('(SSL)' if option.sslclient else ''), '({}{})'.format(option.cipher.name, ' '+','.join(i.name() for i in option.cipher.plugins) if option.cipher and option.cipher.plugins else '') if option.cipher else '')
             try:
-                server = loop.run_until_complete(option.backward.client_run(vars(args)))
+                server = loop.run_until_complete(option.start_backward_client(vars(args)))
                 servers.append(server)
             except Exception as ex:
                 print('Start server failed.\n\t==>', ex)
@@ -622,7 +970,7 @@ def main():
         print('exit')
         if args.sys:
             args.sys.clear()
-        for task in asyncio.Task.all_tasks():
+        for task in asyncio.all_tasks(loop) if hasattr(asyncio, 'all_tasks') else asyncio.Task.all_tasks():
             task.cancel()
         for server in servers:
             server.close()
diff --git a/pproxy/verbose.py b/pproxy/verbose.py
index 6c252e9..ab47436 100644
--- a/pproxy/verbose.py
+++ b/pproxy/verbose.py
@@ -51,9 +51,10 @@ def verbose(s):
         sys.stdout.flush()
     args.verbose = verbose
     args.stats = {0: [0]*6}
-    def modstat(remote_ip, host_name, stats=args.stats):
+    def modstat(user, remote_ip, host_name, stats=args.stats):
+        u = user.decode().split(':')[0]+':' if isinstance(user, (bytes,bytearray)) else ''
         host_name_2 = '.'.join(host_name.split('.')[-3 if host_name.endswith('.com.cn') else -2:]) if host_name.split('.')[-1].isalpha() else host_name
-        tostat = (stats[0], stats.setdefault(remote_ip, {}).setdefault(host_name_2, [0]*6))
+        tostat = (stats[0], stats.setdefault(u+remote_ip, {}).setdefault(host_name_2, [0]*6))
         return lambda i: lambda s: [st.__setitem__(i, st[i] + s) for st in tostat]
     args.modstat = modstat
     def win_readline(handler):
diff --git a/setup.py b/setup.py
index 652cacb..8043935 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ def find_value(name):
 setup(
     name = find_value('title'),
-    version = find_value('version'),
+    use_scm_version = True,
     description = find_value('description'),
     long_description = read('README.rst'),
     url = find_value('url'),
@@ -41,11 +41,14 @@ def find_value(name):
             'uvloop >= 0.13.0'
         ],
         'sshtunnel': [
-            'asyncssh >= 1.16.0',
+            'asyncssh >= 2.5.0',
+        ],
+        'quic': [
+            'aioquic >= 0.9.7',
         ],
         'daemon': [
             'python-daemon >= 2.2.3',
-        ]
+        ],
     },
     install_requires = [],
     entry_points = {