Skip to content

Commit

Permalink
Introduce Async::Redis::Endpoint.
Browse files Browse the repository at this point in the history
- Handles authentication and database selection.
  • Loading branch information
ioquatix committed Aug 15, 2024
1 parent a363147 commit 29ad229
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 15 deletions.
8 changes: 2 additions & 6 deletions lib/async/redis/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
require_relative 'context/pipeline'
require_relative 'context/transaction'
require_relative 'context/subscribe'
require_relative 'protocol/resp2'
require_relative 'endpoint'

require 'io/endpoint/host_endpoint'
require 'async/pool/controller'
Expand All @@ -23,14 +23,10 @@ module Redis
# Legacy.
ServerError = ::Protocol::Redis::ServerError

def self.local_endpoint(port: 6379)
::IO::Endpoint.tcp('localhost', port)
end

class Client
include ::Protocol::Redis::Methods

def initialize(endpoint = Redis.local_endpoint, protocol: Protocol::RESP2, **options)
def initialize(endpoint = Endpoint.local, protocol: endpoint.protocol, **options)
@endpoint = endpoint
@protocol = protocol

Expand Down
250 changes: 250 additions & 0 deletions lib/async/redis/endpoint.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# frozen_string_literal: true

# Released under the MIT License.
# Copyright, 2024, by Samuel Williams.

require 'io/endpoint'
require 'io/endpoint/host_endpoint'
require 'io/endpoint/ssl_endpoint'

require_relative 'protocol/resp2'
require_relative 'protocol/authenticated'
require_relative 'protocol/selected'

module Async
module Redis
def self.local_endpoint(**options)
Endpoint.local(**options)
end

# Represents a way to connect to a remote Redis server.
class Endpoint < ::IO::Endpoint::Generic
LOCALHOST = URI.parse("redis://localhost").freeze

def self.local(**options)
self.new(LOCALHOST, **options)
end

SCHEMES = {
'redis' => URI::Generic,
'rediss' => URI::Generic,
}

def self.parse(string, endpoint = nil, **options)
url = URI.parse(string).normalize

return self.new(url, endpoint, **options)
end

# Construct an endpoint with a specified scheme, hostname, optional path, and options.
#
# @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss".
# @parameter hostname [String] The hostname to connect to (or bind to).
# @parameter *options [Hash] Additional options, passed to {#initialize}.
def self.for(scheme, hostname, credentials: nil, port: nil, database: nil, **options)
uri_klass = SCHEMES.fetch(scheme.downcase) do
raise ArgumentError, "Unsupported scheme: #{scheme.inspect}"
end

if database
path = "/#{database}"
end

self.new(
uri_klass.new(scheme, credentials&.join(":"), hostname, port, nil, path, nil, nil, nil).normalize,
**options
)
end

# Coerce the given object into an endpoint.
# @parameter url [String | Endpoint] The URL or endpoint to convert.
def self.[](object)
if object.is_a?(self)
return object
else
self.parse(object.to_s)
end
end

# @option scheme [String] the scheme to use, overrides the URL scheme.
# @option hostname [String] the hostname to connect to (or bind to), overrides the URL hostname (used for SNI).
# @option port [Integer] the port to bind to, overrides the URL port.
# @option ssl_context [OpenSSL::SSL::SSLContext] the context to use for TLS.
# @option alpn_protocols [Array<String>] the alpn protocols to negotiate.
def initialize(url, endpoint = nil, **options)
super(**options)

raise ArgumentError, "URL must be absolute (include scheme, host): #{url}" unless url.absolute?

@url = url

if endpoint
@endpoint = self.build_endpoint(endpoint)
else
@endpoint = nil
end
end

def to_url
url = @url.dup

unless default_port?
url.port = self.port
end

return url
end

def to_s
"\#<#{self.class} #{self.to_url} #{@options}>"
end

def inspect
"\#<#{self.class} #{self.to_url} #{@options.inspect}>"
end

attr :url

def address
endpoint.address
end

def secure?
['rediss'].include?(self.scheme)
end

def protocol
protocol = @options.fetch(:protocol, Protocol::RESP2)

if database = self.database
protocol = Protocol::Selected.new(database, protocol)
end

if credentials = self.credentials
protocol = Protocol::Authenticated.new(credentials, protocol)
end

return protocol
end

def default_port
6379
end

def default_port?
port == default_port
end

def port
@options[:port] || @url.port || default_port
end

# The hostname is the server we are connecting to:
def hostname
@options[:hostname] || @url.hostname
end

def scheme
@options[:scheme] || @url.scheme
end

def database
@options[:database] || @url.path[1..-1].to_i
end

def credentials
@options[:credentials] || @url.userinfo&.split(":")
end

def localhost?
@url.hostname =~ /^(.*?\.)?localhost\.?$/
end

# We don't try to validate peer certificates when talking to localhost because they would always be self-signed.
def ssl_verify_mode
if self.localhost?
OpenSSL::SSL::VERIFY_NONE
else
OpenSSL::SSL::VERIFY_PEER
end
end

def ssl_context
@options[:ssl_context] || OpenSSL::SSL::SSLContext.new.tap do |context|
context.set_params(
verify_mode: self.ssl_verify_mode
)
end
end

def build_endpoint(endpoint = nil)
endpoint ||= tcp_endpoint

if secure?
# Wrap it in SSL:
return ::IO::Endpoint::SSLEndpoint.new(endpoint,
ssl_context: self.ssl_context,
hostname: @url.hostname,
timeout: self.timeout,
)
end

return endpoint
end

def endpoint
@endpoint ||= build_endpoint
end

def endpoint=(endpoint)
@endpoint = build_endpoint(endpoint)
end

def bind(*arguments, &block)
endpoint.bind(*arguments, &block)
end

def connect(&block)
endpoint.connect(&block)
end

def each
return to_enum unless block_given?

self.tcp_endpoint.each do |endpoint|
yield self.class.new(@url, endpoint, **@options)
end
end

def key
[@url, @options]
end

def eql? other
self.key.eql? other.key
end

def hash
self.key.hash
end

protected

def tcp_options
options = @options.dup

options.delete(:scheme)
options.delete(:port)
options.delete(:hostname)
options.delete(:ssl_context)
options.delete(:protocol)

return options
end

def tcp_endpoint
::IO::Endpoint.tcp(self.hostname, port, **tcp_options)
end
end
end
end
2 changes: 1 addition & 1 deletion lib/async/redis/protocol/authenticated.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AuthenticationError < StandardError
#
# @parameter credentials [Array] The credentials to use for authentication.
# @parameter protocol [Object] The delegated protocol for connecting.
def initialize(credentials, protocol: Async::Redis::Protocol::RESP2)
def initialize(credentials, protocol = Async::Redis::Protocol::RESP2)
@credentials = credentials
@protocol = protocol
end
Expand Down
2 changes: 1 addition & 1 deletion lib/async/redis/protocol/selected.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SelectionError < StandardError
#
# @parameter index [Integer] The database index to select.
# @parameter protocol [Object] The delegated protocol for connecting.
def initialize(index, protocol: Async::Redis::Protocol::RESP2)
def initialize(index, protocol = Async::Redis::Protocol::RESP2)
@index = index
@protocol = protocol
end
Expand Down
26 changes: 19 additions & 7 deletions test/async/redis/disconnect.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,30 @@

describe Async::Redis::Client do
include Sus::Fixtures::Async::ReactorContext

let(:endpoint) {::IO::Endpoint.tcp('localhost', 5555)}


# Intended to not be connected:
let(:endpoint) {Async::Redis::Endpoint.local(port: 5555)}

before do
@server_endpoint = ::IO::Endpoint.tcp("localhost").bound
end

after do
@server_endpoint&.close
end

it "should raise error on unexpected disconnect" do
server_task = reactor.async do
endpoint.accept do |connection|
server_task = Async do
@server_endpoint.accept do |connection|
connection.read(8)
connection.close
end
end

client = Async::Redis::Client.new(endpoint)

client = Async::Redis::Client.new(
@server_endpoint.local_address_endpoint,
protocol: Async::Redis::Protocol::RESP2,
)

expect do
client.call("GET", "test")
Expand Down
43 changes: 43 additions & 0 deletions test/async/redis/endpoint.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# frozen_string_literal: true

# Released under the MIT License.
# Copyright, 2024, by Samuel Williams.

require 'async/redis/client'
require 'async/redis/protocol/authenticated'
require 'sus/fixtures/async'

describe Async::Redis::Protocol::Authenticated do
include Sus::Fixtures::Async::ReactorContext

let(:endpoint) {Async::Redis.local_endpoint}
let(:credentials) {["testuser", "testpassword"]}
let(:protocol) {subject.new(credentials)}
let(:client) {Async::Redis::Client.new(endpoint, protocol: protocol)}

before do
# Setup ACL user with limited permissions for testing.
admin_client = Async::Redis::Client.new(endpoint)
admin_client.call("ACL", "SETUSER", "testuser", "on", ">" + credentials[1], "+ping", "+auth")
ensure
admin_client.close
end

after do
# Cleanup ACL user after tests.
admin_client = Async::Redis::Client.new(endpoint)
admin_client.call("ACL", "DELUSER", "testuser")
admin_client.close
end

it "can authenticate and send allowed commands" do
response = client.call("PING")
expect(response).to be == "PONG"
end

it "rejects commands not allowed by ACL" do
expect do
client.call("SET", "key", "value")
end.to raise_exception(Protocol::Redis::ServerError, message: be =~ /NOPERM/)
end
end

0 comments on commit 29ad229

Please sign in to comment.