Skip to content

feat: allow skipping upgrade steps for incoming connections #1502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/keychain/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ export interface KeyChainComponents {
*/
export class KeyChain implements Startable {
private readonly components: KeyChainComponents
private init: KeyChainInit
private readonly init: KeyChainInit
private started: boolean

/**
Expand Down
56 changes: 37 additions & 19 deletions src/upgrader.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class DefaultUpgrader extends EventEmitter<UpgraderEvents> implements Upg
/**
* Upgrades an inbound connection
*/
async upgradeInbound (maConn: MultiaddrConnection): Promise<Connection> {
async upgradeInbound (maConn: MultiaddrConnection, opts?: UpgraderOptions): Promise<Connection> {
const accept = await this.components.connectionManager.acceptIncomingConnection(maConn)

if (!accept) {
Expand Down Expand Up @@ -166,38 +166,56 @@ export class DefaultUpgrader extends EventEmitter<UpgraderEvents> implements Upg

// Protect
let protectedConn = maConn
const protector = this.components.connectionProtector

if (protector != null) {
log('protecting the inbound connection')
protectedConn = await protector.protect(maConn)
if (opts?.skipProtection !== true) {
const protector = this.components.connectionProtector

if (protector != null) {
log('protecting the inbound connection')
protectedConn = await protector.protect(maConn)
}
}

try {
// Encrypt the connection
({
conn: encryptedConn,
remotePeer,
protocol: cryptoProtocol
} = await this._encryptInbound(protectedConn))
encryptedConn = protectedConn
if (opts?.skipEncryption !== true) {
({
conn: encryptedConn,
remotePeer,
protocol: cryptoProtocol
} = await this._encryptInbound(protectedConn))

if (await this.components.connectionGater.denyInboundEncryptedConnection(remotePeer, {
...protectedConn,
...encryptedConn
})) {
throw errCode(new Error('The multiaddr connection is blocked by gater.acceptEncryptedConnection'), codes.ERR_CONNECTION_INTERCEPTED)
}
} else {
const idStr = maConn.remoteAddr.getPeerId()

if (await this.components.connectionGater.denyInboundEncryptedConnection(remotePeer, {
...protectedConn,
...encryptedConn
})) {
throw errCode(new Error('The multiaddr connection is blocked by gater.acceptEncryptedConnection'), codes.ERR_CONNECTION_INTERCEPTED)
if (idStr == null) {
throw errCode(new Error('inbound connection that skipped encryption must have a peer id'), codes.ERR_INVALID_MULTIADDR)
}

const remotePeerId = peerIdFromString(idStr)

cryptoProtocol = 'native'
remotePeer = remotePeerId
}

// Multiplex the connection
if (this.muxers.size > 0) {
upgradedConn = encryptedConn
if (opts?.muxerFactory != null) {
muxerFactory = opts.muxerFactory
} else if (this.muxers.size > 0) {
// Multiplex the connection
const multiplexed = await this._multiplexInbound({
...protectedConn,
...encryptedConn
}, this.muxers)
muxerFactory = multiplexed.muxerFactory
upgradedConn = multiplexed.stream
} else {
upgradedConn = encryptedConn
}
} catch (err: any) {
log.error('Failed to upgrade inbound connection', err)
Expand Down
95 changes: 88 additions & 7 deletions test/upgrading/upgrader.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string'
import swarmKey from '../fixtures/swarm.key.js'
import { DefaultUpgrader } from '../../src/upgrader.js'
import { codes } from '../../src/errors.js'
import { mockConnectionGater, mockConnectionManager, mockMultiaddrConnPair, mockRegistrar, mockStream } from '@libp2p/interface-mocks'
import { mockConnectionGater, mockConnectionManager, mockMultiaddrConnPair, mockRegistrar, mockStream, mockMuxer } from '@libp2p/interface-mocks'
import Peers from '../fixtures/peers.js'
import type { Upgrader } from '@libp2p/interface-transport'
import type { PeerId } from '@libp2p/interface-peer-id'
import { createFromJSON } from '@libp2p/peer-id-factory'
import { plaintext } from '../../src/insecure/index.js'
import type { ConnectionEncrypter, SecuredConnection } from '@libp2p/interface-connection-encrypter'
import type { StreamMuxer, StreamMuxerFactory, StreamMuxerInit } from '@libp2p/interface-stream-muxer'
import type { Stream } from '@libp2p/interface-connection'
import type { ConnectionProtector, Stream } from '@libp2p/interface-connection'
import pDefer from 'p-defer'
import { createLibp2pNode, Libp2pNode } from '../../src/libp2p.js'
import { pEvent } from 'p-event'
Expand All @@ -32,6 +32,7 @@ import { Uint8ArrayList } from 'uint8arraylist'
import { PersistentPeerStore } from '@libp2p/peer-store'
import { MemoryDatastore } from 'datastore-core'
import { DefaultComponents } from '../../src/components.js'
import { StubbedInstance, stubInterface } from 'sinon-ts'

const addrs = [
multiaddr('/ip4/127.0.0.1/tcp/0'),
Expand All @@ -41,7 +42,12 @@ const addrs = [
describe('Upgrader', () => {
let localUpgrader: Upgrader
let localMuxerFactory: StreamMuxerFactory
let localConnectionEncrypter: ConnectionEncrypter
let localConnectionProtector: StubbedInstance<ConnectionProtector>
let remoteUpgrader: Upgrader
let remoteMuxerFactory: StreamMuxerFactory
let remoteConnectionEncrypter: ConnectionEncrypter
let remoteConnectionProtector: StubbedInstance<ConnectionProtector>
let localPeer: PeerId
let remotePeer: PeerId
let localComponents: DefaultComponents
Expand All @@ -56,39 +62,50 @@ describe('Upgrader', () => {
createFromJSON(Peers[1])
]))

localConnectionProtector = stubInterface<ConnectionProtector>()
localConnectionProtector.protect.resolvesArg(0)

localComponents = new DefaultComponents({
peerId: localPeer,
connectionGater: mockConnectionGater(),
registrar: mockRegistrar(),
datastore: new MemoryDatastore()
datastore: new MemoryDatastore(),
connectionProtector: localConnectionProtector
})
localComponents.peerStore = new PersistentPeerStore(localComponents)
localComponents.connectionManager = mockConnectionManager(localComponents)
localMuxerFactory = mplex()()
localConnectionEncrypter = plaintext()()
localUpgrader = new DefaultUpgrader(localComponents, {
connectionEncryption: [
plaintext()()
localConnectionEncrypter
],
muxers: [
localMuxerFactory
],
inboundUpgradeTimeout: 1000
})

remoteConnectionProtector = stubInterface<ConnectionProtector>()
remoteConnectionProtector.protect.resolvesArg(0)

remoteComponents = new DefaultComponents({
peerId: remotePeer,
connectionGater: mockConnectionGater(),
registrar: mockRegistrar(),
datastore: new MemoryDatastore()
datastore: new MemoryDatastore(),
connectionProtector: remoteConnectionProtector
})
remoteComponents.peerStore = new PersistentPeerStore(remoteComponents)
remoteComponents.connectionManager = mockConnectionManager(remoteComponents)
remoteMuxerFactory = mplex()()
remoteConnectionEncrypter = plaintext()()
remoteUpgrader = new DefaultUpgrader(remoteComponents, {
connectionEncryption: [
plaintext()()
remoteConnectionEncrypter
],
muxers: [
mplex()()
remoteMuxerFactory
],
inboundUpgradeTimeout: 1000
})
Expand Down Expand Up @@ -451,6 +468,70 @@ describe('Upgrader', () => {
expect(connections[0].streams).to.have.lengthOf(0)
expect(connections[1].streams).to.have.lengthOf(0)
})

it('should allow skipping encryption, protection and muxing', async () => {
const localStreamMuxerFactorySpy = sinon.spy(localMuxerFactory, 'createStreamMuxer')
const localMuxerFactoryOverride = mockMuxer()
const localStreamMuxerFactoryOverrideSpy = sinon.spy(localMuxerFactoryOverride, 'createStreamMuxer')
const localConnectionEncrypterSpy = sinon.spy(localConnectionEncrypter, 'secureOutbound')

const remoteStreamMuxerFactorySpy = sinon.spy(remoteMuxerFactory, 'createStreamMuxer')
const remoteMuxerFactoryOverride = mockMuxer()
const remoteStreamMuxerFactoryOverrideSpy = sinon.spy(remoteMuxerFactoryOverride, 'createStreamMuxer')
const remoteConnectionEncrypterSpy = sinon.spy(remoteConnectionEncrypter, 'secureInbound')

const { inbound, outbound } = mockMultiaddrConnPair({
addrs: [
multiaddr('/ip4/127.0.0.1/tcp/0').encapsulate(`/p2p/${remotePeer.toString()}`),
multiaddr('/ip4/127.0.0.1/tcp/0')
],
remotePeer
})

const connections = await Promise.all([
localUpgrader.upgradeOutbound(outbound, {
skipEncryption: true,
skipProtection: true,
muxerFactory: localMuxerFactoryOverride
}),
remoteUpgrader.upgradeInbound(inbound, {
skipEncryption: true,
skipProtection: true,
muxerFactory: remoteMuxerFactoryOverride
})
])

expect(connections).to.have.length(2)

const stream = await connections[0].newStream('/echo/1.0.0')
expect(stream).to.have.nested.property('stat.protocol', '/echo/1.0.0')

const hello = uint8ArrayFromString('hello there!')
const result = await pipe(
[hello],
stream,
function toBuffer (source) {
return (async function * () {
for await (const val of source) yield val.slice()
})()
},
async (source) => await all(source)
)

expect(result).to.eql([hello])

expect(localStreamMuxerFactorySpy.callCount).to.equal(0, 'did not use passed stream muxer factory')
expect(localStreamMuxerFactoryOverrideSpy.callCount).to.equal(1, 'did not use passed stream muxer factory')

expect(remoteStreamMuxerFactorySpy.callCount).to.equal(0, 'did not use passed stream muxer factory')
expect(remoteStreamMuxerFactoryOverrideSpy.callCount).to.equal(1, 'did not use passed stream muxer factory')

expect(localConnectionEncrypterSpy.callCount).to.equal(0, 'used local connection encrypter')
expect(remoteConnectionEncrypterSpy.callCount).to.equal(0, 'used remote connection encrypter')

expect(localConnectionProtector.protect.callCount).to.equal(0, 'used local connection protector')
expect(remoteConnectionProtector.protect.callCount).to.equal(0, 'used remote connection protector')
})
})

describe('libp2p.upgrader', () => {
Expand Down