Skip to content

Commit

Permalink
allow only one sender to connect
Browse files Browse the repository at this point in the history
  • Loading branch information
RCL98 committed Nov 13, 2023
1 parent e5ba8c5 commit ee927bd
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 47 deletions.
5 changes: 3 additions & 2 deletions src/croc/croc.go
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ func (c *Client) transferOverLocalRelay(errchan chan error) {
// not really an error because it will try to connect over the actual relay
return
}
log.Debugf("local connection established: %+v", conn)
log.Debugf("local sender connection established: %+v", conn)
err = nil
for {
data, errConn := conn.Receive()
Expand Down Expand Up @@ -800,7 +800,7 @@ func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, t
errchan <- err
} else {
log.Debugf("banner: %s", banner)
log.Debugf("connection established: %+v", conn)
log.Debugf("sender connection established: %+v", conn)

c.listenToMainConn(conn, ipaddr, banner, errchan)
}
Expand Down Expand Up @@ -939,6 +939,7 @@ func (c *Client) Receive() (err error) {
c.conn[0], banner, c.ExternalIP, err = tcp.ConnectToTCPServer(address, c.Options.RelayPassword, c.Options.SharedSecret[:3], c.Options.IsSender, true, 1, time.Duration(c.Options.TimeLimit)*time.Second)
if err == nil {
c.Options.RelayAddress = address
log.Debug("receiver connection established")
break
}
log.Debugf("could not establish '%s'", address)
Expand Down
6 changes: 3 additions & 3 deletions src/croc/croc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func TestCrocLocal(t *testing.T) {
IsSender: true,
TimeLimit: 30,
MaxTransfers: 1,
SharedSecret: "8123-testingthecroc",
SharedSecret: "2813-testingthecroc",
Debug: true,
RelayAddress: "127.0.0.1:8181",
RelayPorts: []string{"8181", "8182"},
Expand All @@ -297,7 +297,7 @@ func TestCrocLocal(t *testing.T) {
log.Debug("setting up receiver")
receiver, err := New(Options{
IsSender: false,
SharedSecret: "8123-testingthecroc",
SharedSecret: "2813-testingthecroc",
Debug: true,
RelayAddress: "127.0.0.1:8181",
RelayPassword: "pass123",
Expand Down Expand Up @@ -325,7 +325,7 @@ func TestCrocLocal(t *testing.T) {
}
wg.Done()
}()
time.Sleep(100 * time.Millisecond)
time.Sleep(300 * time.Millisecond)
go func() {
err := receiver.Receive()
if err != nil {
Expand Down
96 changes: 54 additions & 42 deletions src/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ func (s *server) run() (err error) {
log.Debugf("checking connection of room %s for %+v", room, c)
deleteIt := false
s.rooms.Lock()
if _, ok := s.rooms.rooms[room]; !ok {
log.Debug("room is gone")
if _, ok := s.rooms.rooms[room]; !ok || (s.rooms.rooms[room].sender == nil && s.rooms.rooms[room].receiver == nil) {
log.Debugf("room %s is gone", room)
s.rooms.Unlock()
return
}
Expand Down Expand Up @@ -312,6 +312,12 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
_, roomExists := s.rooms.rooms[room]
// create the room if it is new
if !roomExists || isSender {
if roomExists && isSender && s.rooms.rooms[room].sender != nil {
// if the room exists and the sender is already connected
// then signal to the client that the room is full
err = s.sendRoomIsFull(c, strongKeyForEncryption)
return
}
err = s.createOrUpdateRoom(c, room, strongKeyForEncryption, isMainRoom, isSender, roomExists)
if err != nil {
log.Error(err)
Expand Down Expand Up @@ -378,6 +384,12 @@ func (s *server) sendRoomIsFull(c *comm.Comm, strongKeyForEncryption []byte) (er
func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncryption []byte, isMainRoom, isSender, updateRoom bool) (err error) {
var enc, data, bSend []byte

if !updateRoom {
log.Debugf("Creating room %s", room)
} else {
log.Debugf("Updating room %s", room)
}

maxTransfers := 1
if isMainRoom && isSender {
log.Debug("Wait for maxTransfers")
Expand Down Expand Up @@ -447,13 +459,6 @@ func (s *server) createOrUpdateRoom(c *comm.Comm, room string, strongKeyForEncry
return
}

func removeReceiverFromQueue(queue *list.List, c *comm.Comm) {
var rmElem *list.Element
for rmElem = queue.Front(); rmElem.Value != c.ID(); rmElem = rmElem.Next() {
}
queue.Remove(rmElem)
}

func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error, keepGoing bool) {
var bSend []byte
log.Debugf("room %s is full, adding to queue", room)
Expand All @@ -477,24 +482,10 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
for {
s.rooms.roomLocks[room].Lock()

if s.rooms.rooms[room].doneTransfers >= s.rooms.rooms[room].maxTransfers {
// remove the client from the queue
newQueue := s.rooms.rooms[room].queue
removeReceiverFromQueue(newQueue, c)
s.rooms.Lock()
s.rooms.rooms[room] = roomInfo{
sender: s.rooms.rooms[room].sender,
receiver: s.rooms.rooms[room].receiver,
isMainRoom: s.rooms.rooms[room].isMainRoom,
opened: s.rooms.rooms[room].opened,
maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: s.rooms.rooms[room].doneTransfers,
queue: newQueue,
}
s.rooms.Unlock()

// tell the client that the sender is no longer available
bSend, err = crypt.Encrypt([]byte(senderGone), strongKeyForEncryption)
if s.rooms.rooms[room].receiver != nil || s.rooms.rooms[room].queue.Front().Value.(string) != c.ID() {
time.Sleep(1 * time.Second)
// tell the client that they need to wait
bSend, err = crypt.Encrypt([]byte("wait"), strongKeyForEncryption)
if err != nil {
return
}
Expand All @@ -503,11 +494,10 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
log.Error(err)
return
}
break
} else if s.rooms.rooms[room].receiver != nil || s.rooms.rooms[room].queue.Front().Value.(string) != c.ID() {
time.Sleep(1 * time.Second)
// tell the client that they need to wait
bSend, err = crypt.Encrypt([]byte("wait"), strongKeyForEncryption)
s.rooms.roomLocks[room].Unlock()
} else if s.rooms.rooms[room].doneTransfers >= s.rooms.rooms[room].maxTransfers {
// tell the client that the sender is no longer available
bSend, err = crypt.Encrypt([]byte(senderGone), strongKeyForEncryption)
if err != nil {
return
}
Expand All @@ -516,12 +506,12 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
log.Error(err)
return
}
s.rooms.roomLocks[room].Unlock()
break
} else {
s.rooms.Lock()
// remove the client from the queue
newQueue := s.rooms.rooms[room].queue
removeReceiverFromQueue(newQueue, c)
newQueue.Remove(newQueue.Front())
s.rooms.rooms[room] = roomInfo{
sender: s.rooms.rooms[room].sender,
receiver: c,
Expand All @@ -541,6 +531,12 @@ func (s *server) handleWaitingRoomForReceivers(c *comm.Comm, room string, strong
func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption []byte) (err error) {
s.rooms.Unlock()

// safety check (it should never happen)
if s.rooms.rooms[room].sender == nil || s.rooms.rooms[room].receiver == nil {
err = fmt.Errorf("sender or receiver is nil")
return
}

// second connection is the sender, time to staple connections
var wg sync.WaitGroup
wg.Add(1)
Expand Down Expand Up @@ -587,8 +583,8 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
s.rooms.Unlock()

// delete the room if it is done or unlock it if it is not
if newDoneTransfers == s.rooms.rooms[room].maxTransfers {
log.Debugf("room %s is done", room)
if newDoneTransfers >= s.rooms.rooms[room].maxTransfers {
log.Debugf("room %s is done, deleting it", room)
s.deleteRoom(room)
} else {
log.Debugf("room %s has %d done", room, newDoneTransfers)
Expand All @@ -598,6 +594,8 @@ func (s *server) beginTransfer(c *comm.Comm, room string, strongKeyForEncryption
}

func (s *server) deleteRoom(room string) {
s.rooms.Lock()
defer s.rooms.Unlock()
if _, ok := s.rooms.rooms[room]; !ok {
return
}
Expand All @@ -609,11 +607,21 @@ func (s *server) deleteRoom(room string) {
break
}
s.rooms.roomLocks[room].Lock()
// remove the client from the queue
newQueue := s.rooms.rooms[room].queue
newQueue.Remove(newQueue.Front())
s.rooms.rooms[room] = roomInfo{
sender: s.rooms.rooms[room].sender,
receiver: s.rooms.rooms[room].receiver,
isMainRoom: s.rooms.rooms[room].isMainRoom,
opened: s.rooms.rooms[room].opened,
maxTransfers: s.rooms.rooms[room].maxTransfers,
doneTransfers: s.rooms.rooms[room].doneTransfers,
queue: newQueue,
}
}
delete(s.rooms.roomLocks, room)
}
s.rooms.Lock()
defer s.rooms.Unlock()
log.Debugf("deleting room: %s", room)
if s.rooms.rooms[room].sender != nil {
s.rooms.rooms[room].sender.Close()
Expand Down Expand Up @@ -871,6 +879,13 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
log.Debug(err)
return
}

if bytes.Equal(data, []byte(fullRoom)) {
err = fmt.Errorf("room is full")
c = nil
return
}

if !isSender {
if bytes.Equal(data, []byte("wait")) {
log.Debug("waiting for sender to be free")
Expand All @@ -880,16 +895,13 @@ func ConnectToTCPServer(address, password, room string, isSender, isMainRoom boo
err = fmt.Errorf("sender is gone")
c = nil
return
} else if bytes.Equal(data, []byte(fullRoom)) {
err = fmt.Errorf("room is full")
c = nil
return
} else if bytes.Equal(data, []byte(noRoom)) {
err = fmt.Errorf("room does not exist")
c = nil
return
}
}

if !bytes.Equal(data, []byte("ok")) {
err = fmt.Errorf("got bad response: %s", data)
log.Debug(err)
Expand Down
16 changes: 16 additions & 0 deletions src/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ func TestTCPServerPing(t *testing.T) {
assert.NotNil(t, err)
}

func TestOnlyOneSenderPerRoom(t *testing.T) {
log.SetLevel("error")
go Run("debug", "127.0.0.1", "8381", "pass123", "8382")
time.Sleep(100 * time.Millisecond)

c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute)
assert.Nil(t, err)
assert.NotNil(t, c1)
assert.Equal(t, banner, "8382")

c2, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", true, true, 1, 1*time.Minute)
assert.NotNil(t, err)
assert.True(t, strings.Contains(err.Error(), "room is full"))
assert.Nil(t, c2)
}

// This is helper function to test that a mocks a transfer
// between two clients connected to the server,
// and checks that the data is transferred correctly
Expand Down

0 comments on commit ee927bd

Please sign in to comment.