Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cookies #38

Merged
merged 7 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class DefaultRakServerConfig extends DefaultChannelConfig implements RakS
private volatile int packetLimit = RakConstants.DEFAULT_PACKET_LIMIT;
private volatile int globalPacketLimit = RakConstants.DEFAULT_GLOBAL_PACKET_LIMIT;
private volatile int unconnectedPacketLimit = RakConstants.DEFAULT_OFFLINE_PACKET_LIMIT;
private volatile boolean sendCookie;

public DefaultRakServerConfig(RakServerChannel channel) {
super(channel);
Expand Down Expand Up @@ -98,6 +99,9 @@ public <T> T getOption(ChannelOption<T> option) {
if (option == RakChannelOption.RAK_OFFLINE_PACKET_LIMIT) {
return (T) Integer.valueOf(this.getUnconnectedPacketLimit());
}
if (option == RakChannelOption.RAK_SEND_COOKIE) {
return (T) Boolean.valueOf(this.sendCookie);
}
return this.channel.parent().config().getOption(option);
}

Expand Down Expand Up @@ -129,6 +133,8 @@ public <T> boolean setOption(ChannelOption<T> option, T value) {
this.setUnconnectedPacketLimit((Integer) value);
} else if (option == RakChannelOption.RAK_GLOBAL_PACKET_LIMIT) {
this.setGlobalPacketLimit((Integer) value);
} else if (option == RakChannelOption.RAK_SEND_COOKIE) {
this.sendCookie = (Boolean) value;
} else {
return this.channel.parent().config().setOption(option, value);
}
Expand Down Expand Up @@ -255,6 +261,11 @@ public int getPacketLimit() {
return this.packetLimit;
}

@Override
public boolean getSendCookie() {
return this.sendCookie;
}

@Override
public int getUnconnectedPacketLimit() {
return unconnectedPacketLimit;
Expand All @@ -274,4 +285,9 @@ public int getGlobalPacketLimit() {
public void setGlobalPacketLimit(int globalPacketLimit) {
this.globalPacketLimit = globalPacketLimit;
}

@Override
public void setSendCookie(boolean sendCookie) {
this.sendCookie = sendCookie;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ public class RakChannelOption<T> extends ChannelOption<T> {
public static final ChannelOption<Integer> RAK_GLOBAL_PACKET_LIMIT =
valueOf(RakChannelOption.class, "RAK_GLOBAL_PACKET_LIMIT");

/**
* Whether to send a cookie to the client during the connection process.
*/
public static final ChannelOption<Boolean> RAK_SEND_COOKIE =
Kas-tle marked this conversation as resolved.
Show resolved Hide resolved
valueOf(RakChannelOption.class, "RAK_SEND_COOKIE");

@SuppressWarnings("deprecation")
protected RakChannelOption() {
super(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,8 @@ public interface RakServerChannelConfig extends ChannelConfig {
int getUnconnectedPacketLimit();

void setUnconnectedPacketLimit(int limit);

boolean getSendCookie();

void setSendCookie(boolean sendCookie);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.channel.*;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.util.concurrent.ScheduledFuture;
import org.cloudburstmc.netty.channel.raknet.RakChannel;
Expand All @@ -45,6 +44,8 @@ public class RakClientOfflineHandler extends SimpleChannelInboundHandler<ByteBuf

private RakOfflineState state = RakOfflineState.HANDSHAKE_1;
private int connectionAttempts;
private int cookie;
private boolean security;

public RakClientOfflineHandler(RakChannel rakChannel, ChannelPromise promise) {
this.rakChannel = rakChannel;
Expand Down Expand Up @@ -153,24 +154,32 @@ protected void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Excep
private void onOpenConnectionReply1(ChannelHandlerContext ctx, ByteBuf buffer) {
long serverGuid = buffer.readLong();
boolean security = buffer.readBoolean();
int mtu = buffer.readShort();
if (security) {
this.successPromise.tryFailure(new SecurityException());
return;
this.cookie = buffer.readInt();
this.security = true;
}
int mtu = buffer.readShort();

this.rakChannel.config().setOption(RakChannelOption.RAK_MTU, mtu);
this.rakChannel.config().setOption(RakChannelOption.RAK_REMOTE_GUID, serverGuid);

this.state = RakOfflineState.HANDSHAKE_2;
this.sendOpenConnectionRequest2(ctx.channel());
if (this.security) {
this.sendOpenConnectionRequest2(ctx.channel(), this.cookie);
} else {
this.sendOpenConnectionRequest2(ctx.channel());
}
Kas-tle marked this conversation as resolved.
Show resolved Hide resolved
}

private void onOpenConnectionReply2(ChannelHandlerContext ctx, ByteBuf buffer) {
buffer.readLong(); // serverGuid
RakUtils.readAddress(buffer); // serverAddress
int mtu = buffer.readShort();
buffer.readBoolean(); // security
boolean security = buffer.readBoolean(); // security
if (security) {
this.successPromise.tryFailure(new SecurityException());
return;
}

this.rakChannel.config().setOption(RakChannelOption.RAK_MTU, mtu);
this.state = RakOfflineState.HANDSHAKE_COMPLETED;
Expand Down Expand Up @@ -209,6 +218,21 @@ private void sendOpenConnectionRequest2(Channel channel) {
channel.writeAndFlush(request);
}

private void sendOpenConnectionRequest2(Channel channel, int cookie) {
int mtuSize = this.rakChannel.config().getOption(RakChannelOption.RAK_MTU);
ByteBuf magicBuf = this.rakChannel.config().getOption(RakChannelOption.RAK_UNCONNECTED_MAGIC);

ByteBuf request = channel.alloc().ioBuffer(39);
request.writeByte(ID_OPEN_CONNECTION_REQUEST_2);
request.writeBytes(magicBuf, magicBuf.readerIndex(), magicBuf.readableBytes());
request.writeInt(cookie);
request.writeBoolean(false); // Client wrote challenge
RakUtils.writeAddress(request, (InetSocketAddress) channel.remoteAddress());
request.writeShort(mtuSize);
request.writeLong(this.rakChannel.config().getOption(RakChannelOption.RAK_GUID));
channel.writeAndFlush(request);
}

private static void safeCancel(ScheduledFuture<?> future, Channel channel) {
channel.eventLoop().execute(() -> { // Make sure this is not called at two places at the same time
if (!future.isCancelled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -47,6 +48,8 @@ public class RakServerOfflineHandler extends AdvancedChannelInboundHandler<Datag

private static final InternalLogger log = InternalLoggerFactory.getInstance(RakServerOfflineHandler.class);

private final ThreadLocal<SecureRandom> random = ThreadLocal.withInitial(SecureRandom::new);

private final ExpiringMap<InetSocketAddress, Integer> pendingConnections = ExpiringMap.builder()
.expiration(10, TimeUnit.SECONDS)
.expirationPolicy(ExpirationPolicy.CREATED)
Expand All @@ -58,6 +61,12 @@ public class RakServerOfflineHandler extends AdvancedChannelInboundHandler<Datag
.expirationPolicy(ExpirationPolicy.CREATED)
.build();

private final ExpiringMap<InetSocketAddress, Integer> cookies = ExpiringMap.builder()
.expiration(10, TimeUnit.SECONDS)
Kas-tle marked this conversation as resolved.
Show resolved Hide resolved
.expirationPolicy(ExpirationPolicy.CREATED)
.expirationListener((key, value) -> ReferenceCountUtil.release(value))
.build();

private final RakServerChannel channel;

public RakServerOfflineHandler(RakServerChannel channel) {
Expand Down Expand Up @@ -176,11 +185,19 @@ private void onOpenConnectionRequest1(ChannelHandlerContext ctx, DatagramPacket
log.trace("Received duplicate open connection request 1 from {}", sender);
}

ByteBuf replyBuffer = ctx.alloc().ioBuffer(28, 28);
boolean sendCookie = ctx.channel().config().getOption(RakChannelOption.RAK_SEND_COOKIE);
int bufferCapacity = sendCookie ? 32 : 28; // 4 byte cookie

ByteBuf replyBuffer = ctx.alloc().ioBuffer(bufferCapacity, bufferCapacity);
replyBuffer.writeByte(ID_OPEN_CONNECTION_REPLY_1);
replyBuffer.writeBytes(magicBuf, magicBuf.readerIndex(), magicBuf.readableBytes());
replyBuffer.writeLong(guid);
replyBuffer.writeBoolean(false); // Security
replyBuffer.writeBoolean(sendCookie); // Security
if (sendCookie) {
int cookie = this.random.get().nextInt();
this.cookies.put(sender, cookie);
replyBuffer.writeInt(cookie);
}
replyBuffer.writeShort(RakUtils.clamp(mtu, ctx.channel().config().getOption(RakChannelOption.RAK_MIN_MTU), ctx.channel().config().getOption(RakChannelOption.RAK_MAX_MTU)));
ctx.writeAndFlush(new DatagramPacket(replyBuffer, sender));
}
Expand All @@ -191,6 +208,21 @@ private void onOpenConnectionRequest2(ChannelHandlerContext ctx, DatagramPacket
// Skip already verified magic
buffer.skipBytes(magicBuf.readableBytes());

boolean sendCookie = ctx.channel().config().getOption(RakChannelOption.RAK_SEND_COOKIE);
if (sendCookie) {
int cookie = buffer.readInt();
Integer expectedCookie = this.cookies.remove(sender);
if (expectedCookie == null || expectedCookie != cookie) {
if (log.isTraceEnabled()) {
log.trace("Received open connection request 2 from {} with invalid cookie (expected {}, but received {})", sender, expectedCookie, cookie);
}
// Incorrect cookie provided
// This is likely source IP spoofing so we will not reply
return;
}
buffer.readBoolean(); // Client wrote challenge
}

Integer version = this.pendingConnections.remove(sender);
if (version == null) {
// We can't determine the version without the previous request, so assume it's the wrong version.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static InetSocketAddress readAddress(ByteBuf buffer) {
int scopeId = buffer.readInt();
address = Inet6Address.getByAddress(null, addressBytes, scopeId);
} else {
throw new UnsupportedOperationException("Unknown Internet Protocol version.");
throw new UnsupportedOperationException("Unknown Internet Protocol version. Expected 4 or 6, got " + type);
}
} catch (UnknownHostException e) {
throw new IllegalArgumentException(e);
Expand Down
58 changes: 42 additions & 16 deletions transport-raknet/src/test/java/org/cloudburstmc/netty/RakTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,20 @@ private static ServerBootstrap serverBootstrap() {
.option(RakChannelOption.RAK_MAX_CONNECTIONS, 1)
.childOption(RakChannelOption.RAK_ORDERING_CHANNELS, 1)
.option(RakChannelOption.RAK_GUID, ThreadLocalRandom.current().nextLong())
.option(RakChannelOption.RAK_ADVERTISEMENT, Unpooled.wrappedBuffer(ADVERTISEMENT));
.option(RakChannelOption.RAK_ADVERTISEMENT, Unpooled.wrappedBuffer(ADVERTISEMENT))
.handler(new ChannelInitializer<RakServerChannel>() {
@Override
protected void initChannel(RakServerChannel ch) throws Exception {
System.out.println("Initialised server channel");
}
})
.childHandler(new ChannelInitializer<RakChildChannel>() {
@Override
protected void initChannel(RakChildChannel ch) throws Exception {
System.out.println("Server child channel initialized " + ch.remoteAddress());
ch.pipeline().addLast(RESEND_HANDLER());
}
});
}

private static Bootstrap clientBootstrap(int mtu) {
Expand All @@ -117,32 +130,44 @@ private static IntStream validMtu() {
.filter(i -> i % 12 == 0);
}

@BeforeEach
public void setupServer() {
serverBootstrap()
.handler(new ChannelInitializer<RakServerChannel>() {
@Override
protected void initChannel(RakServerChannel ch) throws Exception {
System.out.println("Initialised server channel");
}
})
.childHandler(new ChannelInitializer<RakChildChannel>() {
@Override
protected void initChannel(RakChildChannel ch) throws Exception {
System.out.println("Server child channel initialized " + ch.remoteAddress());
ch.pipeline().addLast(RESEND_HANDLER());
}
})
.bind(new InetSocketAddress("127.0.0.1", 19132))
.awaitUninterruptibly();
}

public void setupCookieServer() {
serverBootstrap()
.option(RakChannelOption.RAK_SEND_COOKIE, true)
.bind(new InetSocketAddress("127.0.0.1", 19132))
.awaitUninterruptibly();
}

@Test
public void testClientConnect() {
setupServer();
int mtu = RakConstants.MAXIMUM_MTU_SIZE;
System.out.println("Testing client with MTU " + mtu);

Channel channel = clientBootstrap(mtu)
clientBootstrap(mtu)
.handler(new ChannelInitializer<RakClientChannel>() {
@Override
protected void initChannel(RakClientChannel ch) throws Exception {
System.out.println("Client channel initialized");
}
})
.connect(new InetSocketAddress("127.0.0.1", 19132))
.awaitUninterruptibly()
.channel();
}

@Test
public void testClientConnectWithCookie() {
setupCookieServer();
int mtu = RakConstants.MAXIMUM_MTU_SIZE;
System.out.println("Testing client with MTU " + mtu + " and cookie enabled");

clientBootstrap(mtu)
.handler(new ChannelInitializer<RakClientChannel>() {
@Override
protected void initChannel(RakClientChannel ch) throws Exception {
Expand All @@ -158,6 +183,7 @@ protected void initChannel(RakClientChannel ch) throws Exception {
@ParameterizedTest
@MethodSource("validMtu")
public void testClientResend(int mtu) {
setupServer();
System.out.println("Testing client with MTU " + mtu);

SecureRandom random = new SecureRandom();
Expand Down