@@ -20,11 +20,12 @@ use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync:
20
20
21
21
use anyhow:: { anyhow, bail, Context , Result } ;
22
22
use derive_more:: Debug ;
23
- use futures_lite:: StreamExt ;
23
+ use futures_lite:: { future :: Boxed , StreamExt } ;
24
24
use http:: {
25
25
response:: Builder as ResponseBuilder , HeaderMap , Method , Request , Response , StatusCode ,
26
26
} ;
27
27
use hyper:: body:: Incoming ;
28
+ use iroh_base:: NodeId ;
28
29
#[ cfg( feature = "test-utils" ) ]
29
30
use iroh_base:: RelayUrl ;
30
31
use iroh_metrics:: inc;
@@ -120,6 +121,40 @@ pub struct RelayConfig<EC: fmt::Debug, EA: fmt::Debug = EC> {
120
121
pub limits : Limits ,
121
122
/// Key cache capacity.
122
123
pub key_cache_capacity : Option < usize > ,
124
+ /// Access configuration.
125
+ pub access : AccessConfig ,
126
+ }
127
+
128
+ /// Controls which nodes are allowed to use the relay.
129
+ #[ derive( derive_more:: Debug ) ]
130
+ pub enum AccessConfig {
131
+ /// Everyone
132
+ Everyone ,
133
+ /// Only nodes for which the function returns `Access::Allow`.
134
+ #[ debug( "restricted" ) ]
135
+ Restricted ( Box < dyn Fn ( NodeId ) -> Boxed < Access > + Send + Sync + ' static > ) ,
136
+ }
137
+
138
+ impl AccessConfig {
139
+ /// Is this node allowed?
140
+ pub async fn is_allowed ( & self , node : NodeId ) -> bool {
141
+ match self {
142
+ Self :: Everyone => true ,
143
+ Self :: Restricted ( check) => {
144
+ let res = check ( node) . await ;
145
+ matches ! ( res, Access :: Allow )
146
+ }
147
+ }
148
+ }
149
+ }
150
+
151
+ /// Access restriction for a node.
152
+ #[ derive( Debug , Copy , Clone , PartialEq , Eq ) ]
153
+ pub enum Access {
154
+ /// Access is allowed.
155
+ Allow ,
156
+ /// Access is denied.
157
+ Deny ,
123
158
}
124
159
125
160
/// Configuration for the STUN server.
@@ -318,6 +353,7 @@ impl Server {
318
353
let mut builder = http_server:: ServerBuilder :: new ( relay_bind_addr)
319
354
. headers ( headers)
320
355
. key_cache_capacity ( key_cache_capacity)
356
+ . access ( relay_config. access )
321
357
. request_handler ( Method :: GET , "/" , Box :: new ( root_handler) )
322
358
. request_handler ( Method :: GET , "/index.html" , Box :: new ( root_handler) )
323
359
. request_handler ( Method :: GET , RELAY_PROBE_PATH , Box :: new ( probe_handler) )
@@ -772,6 +808,7 @@ mod tests {
772
808
use std:: { net:: Ipv4Addr , time:: Duration } ;
773
809
774
810
use bytes:: Bytes ;
811
+ use futures_lite:: FutureExt ;
775
812
use futures_util:: SinkExt ;
776
813
use http:: header:: UPGRADE ;
777
814
use iroh_base:: { NodeId , SecretKey } ;
@@ -790,6 +827,7 @@ mod tests {
790
827
tls : None ,
791
828
limits : Default :: default ( ) ,
792
829
key_cache_capacity : Some ( 1024 ) ,
830
+ access : AccessConfig :: Everyone ,
793
831
} ) ,
794
832
quic : None ,
795
833
stun : None ,
@@ -840,6 +878,7 @@ mod tests {
840
878
tls : None ,
841
879
limits : Default :: default ( ) ,
842
880
key_cache_capacity : Some ( 1024 ) ,
881
+ access : AccessConfig :: Everyone ,
843
882
} ) ,
844
883
stun : None ,
845
884
quic : None ,
@@ -1106,4 +1145,96 @@ mod tests {
1106
1145
assert_eq ! ( txid, txid_back) ;
1107
1146
assert_eq ! ( response_addr, socket. local_addr( ) . unwrap( ) ) ;
1108
1147
}
1148
+
1149
+ #[ tokio:: test]
1150
+ async fn test_relay_access_control ( ) -> Result < ( ) > {
1151
+ let _guard = iroh_test:: logging:: setup ( ) ;
1152
+
1153
+ let a_secret_key = SecretKey :: generate ( rand:: thread_rng ( ) ) ;
1154
+ let a_key = a_secret_key. public ( ) ;
1155
+
1156
+ let server = Server :: spawn ( ServerConfig :: < ( ) , ( ) > {
1157
+ relay : Some ( RelayConfig :: < ( ) , ( ) > {
1158
+ http_bind_addr : ( Ipv4Addr :: LOCALHOST , 0 ) . into ( ) ,
1159
+ tls : None ,
1160
+ limits : Default :: default ( ) ,
1161
+ key_cache_capacity : Some ( 1024 ) ,
1162
+ access : AccessConfig :: Restricted ( Box :: new ( move |node_id| {
1163
+ async move {
1164
+ info ! ( "checking {}" , node_id) ;
1165
+ // reject node a
1166
+ if node_id == a_key {
1167
+ Access :: Deny
1168
+ } else {
1169
+ Access :: Allow
1170
+ }
1171
+ }
1172
+ . boxed ( )
1173
+ } ) ) ,
1174
+ } ) ,
1175
+ quic : None ,
1176
+ stun : None ,
1177
+ metrics_addr : None ,
1178
+ } )
1179
+ . await
1180
+ . unwrap ( ) ;
1181
+ let relay_url = format ! ( "http://{}" , server. http_addr( ) . unwrap( ) ) ;
1182
+ let relay_url: RelayUrl = relay_url. parse ( ) ?;
1183
+
1184
+ // set up client a
1185
+ let resolver = crate :: dns:: default_resolver ( ) . clone ( ) ;
1186
+ let mut client_a = ClientBuilder :: new ( relay_url. clone ( ) , a_secret_key, resolver)
1187
+ . connect ( )
1188
+ . await ?;
1189
+
1190
+ // the next message should be the rejection of the connection
1191
+ tokio:: time:: timeout ( Duration :: from_millis ( 500 ) , async move {
1192
+ match client_a. next ( ) . await . unwrap ( ) . unwrap ( ) {
1193
+ ReceivedMessage :: Health { problem } => {
1194
+ assert_eq ! ( problem, Some ( "not authenticated" . to_string( ) ) ) ;
1195
+ }
1196
+ msg => {
1197
+ panic ! ( "other msg: {:?}" , msg) ;
1198
+ }
1199
+ }
1200
+ } )
1201
+ . await ?;
1202
+
1203
+ // test that another client has access
1204
+
1205
+ // set up client b
1206
+ let b_secret_key = SecretKey :: generate ( rand:: thread_rng ( ) ) ;
1207
+ let b_key = b_secret_key. public ( ) ;
1208
+
1209
+ let resolver = crate :: dns:: default_resolver ( ) . clone ( ) ;
1210
+ let mut client_b = ClientBuilder :: new ( relay_url. clone ( ) , b_secret_key, resolver)
1211
+ . connect ( )
1212
+ . await ?;
1213
+
1214
+ // set up client c
1215
+ let c_secret_key = SecretKey :: generate ( rand:: thread_rng ( ) ) ;
1216
+ let c_key = c_secret_key. public ( ) ;
1217
+
1218
+ let resolver = crate :: dns:: default_resolver ( ) . clone ( ) ;
1219
+ let mut client_c = ClientBuilder :: new ( relay_url. clone ( ) , c_secret_key, resolver)
1220
+ . connect ( )
1221
+ . await ?;
1222
+
1223
+ // send message from b to c
1224
+ let msg = Bytes :: from ( "hello, c" ) ;
1225
+ let res = try_send_recv ( & mut client_b, & mut client_c, c_key, msg. clone ( ) ) . await ?;
1226
+
1227
+ if let ReceivedMessage :: ReceivedPacket {
1228
+ remote_node_id,
1229
+ data,
1230
+ } = res
1231
+ {
1232
+ assert_eq ! ( b_key, remote_node_id) ;
1233
+ assert_eq ! ( msg, data) ;
1234
+ } else {
1235
+ panic ! ( "client_c received unexpected message {res:?}" ) ;
1236
+ }
1237
+
1238
+ Ok ( ( ) )
1239
+ }
1109
1240
}
0 commit comments