-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathconninfo.cr
183 lines (155 loc) · 5.39 KB
/
conninfo.cr
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
require "uri"
require "http"
require "system/user"
module PQ
struct ConnInfo
# Unix socket paths that `default_host` probes, in order, when no host is
# passed and `PGHOST` is not set.
SOCKET_SEARCH = %w(/run/postgresql/.s.PGSQL.5432 /tmp/.s.PGSQL.5432 /var/run/postgresql/.s.PGSQL.5432)
# Names accepted for the `auth_methods` connection parameter; any other name
# is rejected by `handle_sslparam`.
SUPPORTED_AUTH_METHODS = %w[cleartext md5 scram-sha-256 scram-sha-256-plus]
# The host. If it starts with a / it is assumed to be a local Unix socket.
getter host : String
# The port. Defaults to `PGPORT`, then 5432. It is ignored for local Unix sockets.
getter port : Int32
# The database name. Defaults to `PGDATABASE`, then the current user name.
getter database : String
# The user. Defaults to `PGUSER`, then the current user name.
getter user : String
# The password. Optional. Falls back to `PGPASSWORD`, then to `PgPass.locate`.
getter password : String?
# The sslmode. Optional (:prefer is default).
getter sslmode : Symbol
# The sslcert. Optional.
getter sslcert : String?
# The sslkey. Optional.
getter sslkey : String?
# The sslrootcert. Optional.
getter sslrootcert : String?
# The application name. Optional (defaults to `PGAPPNAME`, then "crystal").
getter application_name : String
# The authentication method names for this connection. Replaced by the
# `auth_methods` connection parameter (comma-separated) when one is given.
# NOTE(review): presumably the methods the client is willing to use during
# authentication — confirm against the connection code.
getter auth_methods : Array(String) = %w[scram-sha-256-plus scram-sha-256 md5]
# Creates a new ConnInfo from its individual parts.
#
# Every part is optional; a `nil` part falls back to its default:
# * *host*: `PGHOST`, then the first existing path in `SOCKET_SEARCH`, then "localhost"
# * *database*: `PGDATABASE`, then the current user name (a leading "/" is removed)
# * *user*: `PGUSER`, then the current user name
# * *password*: `PGPASSWORD`, then `PgPass.locate`
# * *port*: `PGPORT`, then 5432
# * *sslmode*: `:prefer`
# * *application_name*: `PGAPPNAME`, then "crystal"
#
# Raises `ArgumentError` when *sslmode* is not a supported mode.
def initialize(host : String? = nil, database : String? = nil, user : String? = nil, password : String? = nil, port : Int | String? = nil, sslmode : String | Symbol? = nil, application_name : String? = nil)
  @host = default_host host
  db = default_database database
  # URI paths arrive as "/dbname"; drop the leading slash.
  @database = db.lchop('/')
  @user = default_user user
  @port = (port || ENV.fetch("PGPORT", "5432")).to_i
  @sslmode = default_sslmode sslmode
  # `||` keeps the pgpass lookup lazy: it only runs when neither an explicit
  # password nor PGPASSWORD is available.
  @password = password || ENV["PGPASSWORD"]? || PgPass.locate(@host, @port, @database, @user)
  @application_name = default_application_name application_name
end
# Builds a ConnInfo from either a "postgres://" / "postgresql://" URL or a
# string of whitespace-separated postgres "key=value" pairs.
#
# A blank string yields a ConnInfo made entirely of defaults.
#
# Raises `ArgumentError` when a pair contains no "=".
#
# NOTE: quoted values (e.g. `password='a b'`) are not supported; every
# whitespace-separated token is treated as one pair.
def self.from_conninfo_string(conninfo : String)
  if conninfo.starts_with?("postgres://") || conninfo.starts_with?("postgresql://")
    return new(URI.parse(conninfo))
  end

  args = Hash(String, String).new
  # Split on runs of whitespace so repeated, leading or trailing blanks do not
  # produce empty tokens that would be rejected as invalid pairs.
  conninfo.split do |pair|
    key, eq, value = pair.partition('=')
    raise ArgumentError.new("invalid parameter: #{pair}") if eq.empty?
    args[key] = value
  end

  return new if args.empty?
  new(args)
end
# Initializes with a `URI`.
#
# Host, database (the URI path), user, password and port are taken from the
# URI itself. When the URI has no hostname, the `host` query parameter is used
# instead. The `application_name` query parameter is honoured, and every query
# parameter is passed to `handle_sslparam`, which applies `sslmode`, `sslcert`,
# `sslkey`, `sslrootcert` and `auth_methods` and ignores the rest.
def initialize(uri : URI)
  params = URI::Params.parse(uri.query.to_s)
  hostname = uri.hostname.presence || params.fetch("host", "")
  # sslmode starts as :prefer and is overridden below if the query sets it.
  initialize(hostname, uri.path, uri.user, uri.password, uri.port, :prefer, params.fetch("application_name", nil))
  # Reuse the parameters parsed above rather than parsing the query string a
  # second time through `HTTP::Params`.
  params.each do |key, value|
    handle_sslparam(key, value)
  end
end
# Initializes from a `Hash` of connection parameters.
#
# Keys follow the Postgres "conninfo" names: `"host"`, `"dbname"`, `"user"`,
# `"password"`, `"port"`, `"sslmode"`, `"sslcert"`, `"sslkey"`,
# `"sslrootcert"`, `"application_name"` and `"auth_methods"`. Missing keys
# fall back to their defaults; keys not listed here are ignored.
def initialize(params : Hash)
  host = params["host"]?
  dbname = params["dbname"]?
  user = params["user"]?
  password = params["password"]?
  port = params["port"]?
  sslmode = params["sslmode"]?
  app_name = params["application_name"]?
  initialize(host, dbname, user, password, port, sslmode, app_name)

  # Second pass picks up the ssl/auth settings that the main initializer
  # does not take as arguments.
  params.each { |name, setting| handle_sslparam(name, setting) }
end
# Applies one optional connection parameter to this ConnInfo.
#
# Handles `sslmode`, `sslcert`, `sslkey`, `sslrootcert` and `auth_methods`;
# any other *key* is ignored.
#
# `auth_methods` is a comma-separated list. Entries are stripped of
# surrounding whitespace, passed through `String#underscore`, de-duplicated,
# and blank entries are dropped.
#
# Raises `ArgumentError` when the sslmode is not supported or when an auth
# method is not listed in `SUPPORTED_AUTH_METHODS`.
private def handle_sslparam(key : String, value : String)
  case key
  when "sslmode"
    @sslmode = default_sslmode value
  when "sslcert"
    @sslcert = value
  when "sslkey"
    @sslkey = value
  when "sslrootcert"
    @sslrootcert = value
  when "auth_methods"
    # Strip first so "md5, scram-sha-256" is accepted like "md5,scram-sha-256".
    methods = value.split(',').compact_map(&.strip.underscore.presence).uniq
    if unsupported = methods.find { |name| !name.in?(SUPPORTED_AUTH_METHODS) }
      # ArgumentError matches how an unsupported sslmode is reported.
      raise ArgumentError.new("invalid auth_method #{unsupported}")
    end
    @auth_methods = methods
  else
    # ignore
  end
end
# Resolves the host to connect to.
#
# Order of precedence:
# 1. *h*, when it is not nil or blank
# 2. `PGHOST`, when it is set and not blank; a value starting with "/" is
#    treated as a socket directory and ".s.PGSQL.5432" is appended
# 3. the first path in `SOCKET_SEARCH` that exists on disk
# 4. "localhost"
#
# NOTE: the socket file name appended to a `PGHOST` directory always uses
# port 5432, regardless of the configured port.
private def default_host(h)
  return h if h && !h.blank?

  # `presence` turns an empty/blank PGHOST into nil so it is treated as unset
  # instead of raising IndexError on `pghost[0]`.
  if pghost = ENV["PGHOST"]?.presence
    return pghost.starts_with?('/') ? "#{pghost}/.s.PGSQL.5432" : pghost
  end

  SOCKET_SEARCH.find { |path| File.exists?(path) } || "localhost"
end
# Resolves the database name.
#
# Returns *db* unchanged when it is a non-empty value other than "/".
# Otherwise returns `PGDATABASE`, or the current user name when that
# variable is not set.
private def default_database(db)
  return db if db && db != "/" && !db.empty?

  # Block form keeps the user-name lookup lazy: it only runs when PGDATABASE
  # is missing, so a failing OS user lookup cannot break a configured setup.
  ENV.fetch("PGDATABASE") { current_user_name }
end
# Resolves the application name: the explicit *application_name* wins, then
# the `PGAPPNAME` environment variable, then *fallback_application_name*.
private def default_application_name(application_name, fallback_application_name = "crystal")
  return application_name if application_name

  from_env = ENV.fetch("PGAPPNAME", nil)
  from_env || fallback_application_name
end
# Resolves the user name: the explicit *u* wins, then the `PGUSER`
# environment variable, then the current user name.
private def default_user(u)
  # Block form keeps the user-name lookup lazy: it only runs when neither *u*
  # nor PGUSER is available, so a failing OS user lookup cannot break a
  # configured setup.
  u || ENV.fetch("PGUSER") { current_user_name }
end
# Normalizes *mode* (a symbol, a string or nil) to one of the supported
# sslmode symbols. A nil *mode* selects `:prefer`.
#
# Raises `ArgumentError` for any other value.
private def default_sslmode(mode)
  return :prefer if mode.nil?

  # Symbols and strings share one spelling, so compare on the string form.
  case mode.to_s
  when "prefer"      then :prefer
  when "disable"     then :disable
  when "allow"       then :allow
  when "require"     then :require
  when "verify-ca"   then :"verify-ca"
  when "verify-full" then :"verify-full"
  else
    raise ArgumentError.new("sslmode #{mode} not supported")
  end
end
# Returns the name used as the default user and database.
#
# On Windows this is the fixed string "postgres"; elsewhere it is the
# username that `System::User` reports for the process's uid.
private def current_user_name
  {% if flag?(:windows) %}
    # NOTE: looking up the real account name on Windows would be better, see
    # https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-getusernamew
    "postgres"
  {% else %}
    uid = LibC.getuid.to_s
    account = System::User.find_by(id: uid)
    account.username
  {% end %}
end
end
end